구간 곱 구하기

백준 11505

GOLD I


  1. 문제 내용
  2. 해결 방안
  3. 풀이 코드

문제 내용

problem.png

문제 링크

해결 방안

2357번(최솟값과 최댓값) 문제에 이어 세그먼트 트리에 관한 문제이다. 주어진 정수들에 대해, 숫자를 바꾸거나 주어진 범위의 값들의 곱을 구하는 여러 쿼리를 처리하는 문제이다. 누적 합 문제처럼 미리 값을 저장하는 방법을 쓰기에는 곱셈이라 값이 매우 커질 수 있고, 정수가 계속 바뀌기 때문에 일일이 계산하는 방법이 더 적합하다. 그럼에도 일일이 계산하는 방법은 N개의 정수와 M개의 범위가 주어지면, 시간복잡도가 O(NM)이 나오므로 시간이 많이 소요된다. 때문에, 주어진 데이터 세트의 특정 범위에 있는 데이터들에 대해 연산하는 데에 특화된 자료구조인 세그먼트 트리를 사용해야 한다.

우선 재귀 방식을 통해 세그먼트 트리를 생성하고, 구간의 곱을 구하는 쿼리와 임의의 인덱스의 값을 변경하는 쿼리를 수행할 함수를 만들어야 한다. 각 리프 노드에는 각 인덱스별 데이터가 들어가고, 리프 노드가 아닌 경우에는 각 구간별 수들의 곱이 들어간다. 구간의 곱을 구하는 쿼리는 현재 구간이 탐색 구간에 포함되는 지에 따라, 곱해도 값에 변화가 없는 1이나 현재 구간의 곱을 반환한다.

임의의 인덱스의 값을 변경하는 경우에는 하위 노드로 내려가면서, 변경할 인덱스의 값와 연산이 된 노드들의 값과 변경할 인덱스의 값을 모두 변경해야 한다. 합을 구하는 문제에 대해서는 차액만큼을 계산하면서 내려가면 되지만, 곱의 경우에는 연산 값에 변경하고자 하는 인덱스의 이전 값으로 나누고 변경할 값을 곱해주어야 한다. 이때, 변경할 인덱스의 이전 값이 0이라면 Zero Division Error가 발생하기 때문에, 0에 대해 예외 처리를 해주거나 값을 변경하고 연산을 처음부터 하면서 다시 올라가는 방법을 사용해야 한다.

해당 함수가 완성되면, 데이터 세트를 입력받아 세그먼트 트리를 만들고 쿼리를 입력받아 각각의 쿼리를 수행하면 된다. 수의 개수와 쿼리의 양이 적지 않으므로, 각 언어별로 Fast I/O를 구현하여 사용하는 것이 권장된다. 본 설명에서는 C++을 사용한 코드를 게재하였다.

세그먼트 트리(Segment Tree)란?

이 설명을 보기에 앞서 무려 BOJ에서 제공하는 자료에서 너무나도 잘 설명되어 있으니 꼭 읽어보기를 추천한다.

세그먼트 트리(segment tree)는 연속적으로 존재하는 여러 개의 데이터에 대해, 특정한 구간의 데이터를 쿼리(query)하는 데에 특화된 데이터 구조이다. 다시 말해, 주어진 데이터 세트에서 특정한 범위 내의 데이터들의 합이나 곱, 또는 최댓값이나 최솟값을 구하는 등의 탐색부터 데이터 변조까지의 과정을 보다 효율적으로 할 수 있게 만들어진 그래프 트리라고 볼 수 있다. 임의의 데이터 세트를 세그먼트 트리로 만드는 데에는 O(Nlog₂N)만큼의 시간복잡도를 가지고, 만들어진 세그먼트 트리로는 O(log₂N)에 쿼리를 수행할 수 있다. 쿼리가 O(log₂N)에 수행될 수 있는 이유는 N개의 데이터에 대해 log₂N의 높이를 가지기 때문인데, 세그먼트 트리는 Full Binary Tree의 형태를 띠므로 높이가 log₂N인 트리의 최대 노드의 개수는 Perfect Binary Tree가 만들어질 때인 2^(log₂N + 1) - 1개이다.

풀이 코드

#include <iostream>
#include <vector>
#include <cmath>
#define MOD 1000000007
/* 구간 곱을 1,000,000,007로 나눈 나머지 처리 */

using namespace std;

// 데이터의 개수와 쿼리의 개수(m: 수의 변경, k: 구간의 곱)
int n, m, k;
// 데이터 세트
long long arr[1000001];
// 구간의 곱에 대한 세그먼트 트리
vector<long long> mul_tree;

// 세그먼트 트리 생성
void init(int node, int start, int end) {
    // 리프 노드의 경우, 데이터 세트의 값 저장
    if(start == end)
        mul_tree[node] = arr[start];
    // 리프 노드가 아닌 노드의 경우, 각 구간의 모든 값의 곱 저장
    else {
        init(node * 2, start, (start + end) / 2);
        init(node * 2 + 1, (start + end) / 2 + 1, end);
        mul_tree[node] = (mul_tree[node * 2] * mul_tree[node * 2 + 1]) % MOD;
    }
}

// 구간의 곱을 반환하는 쿼리
long long query(int node, int start, int end, int left, int right) {
    // 현재 구간이 찾고자 하는 구간을 벗어나면 1 반환(곱해도 값의 변화 없음)
    if(start > right || end < left)
        return 1;

    // 현재 구간이 찾고자 하는 구간에 완전히 포함되면 현재 구간의 곱 반환
    if(start >= left && end <= right)
        return mul_tree[node];
    
    // 현재 구간이 찾고자 하는 구간에 일부 포함되어 있다면 두 부분으로 나누어 쿼리 수행
    long long left_mul = query(node * 2, start, (start + end) / 2, left, right);
    long long right_mul = query(node * 2 + 1, (start + end) / 2 + 1, end, left, right);
    
    // 수행된 쿼리에 대해 1,000,000,007로 나눈 나머지 반환
    return (left_mul * right_mul) % MOD;
}

// 임의의 인덱스의 값을 변경하는 쿼리
void update(int node, int start, int end, int index, long long new_value) {
    // 현재 구간이 찾고자 하는 인덱스를 벗어나면 중단
    if(start > index || end < index)
        return;
    
    // 인덱스 구간에 속하는 리프 노드에 도착하면, 인덱스의 값 변경
    if(start == end) {
        mul_tree[node] = new_value;
        arr[index] = new_value;
    }
    // 인덱스가 구간에는 속하지만 아직 리프 노드가 아니라면, 두 부분으로 나누어 탐색
    // 만약 탐색이 종료되면 리프 노드가 아닌 노드의 구간의 곱 갱신
    else {
        update(node * 2, start, (start + end) / 2, index, new_value);
        update(node * 2 + 1, (start + end) / 2 + 1, end, index, new_value);
        mul_tree[node] = (mul_tree[node * 2] * mul_tree[node * 2 + 1]) % MOD;
    }
}

int main() {
    // Fast IO
    cin.tie(0);
    cout.tie(0);
    ios_base::sync_with_stdio(false);

    // 데이터의 개수와 쿼리의 개수 입력
    cin >> n >> m >> k;
    // 각 데이터 입력
    for(int i = 1; i <= n; i++)
        cin >> arr[i];
    
    // 세그먼트 트리 공간 설정 및 세그먼트 트리 생성
    int height = (int)ceil(log2(n));
    mul_tree = vector<long long>(1 << (height + 1));
    init(1, 1, n);

    // 각 쿼리를 입력받아 인덱스 값 변경 및 구간의 곱 출력
    int a, b, c;
    for(int i = 0; i < m + k; i++) {
        cin >> a >> b >> c;
        if(a == 1)
            update(1, 1, n, b, c);
        else if(a == 2)
            cout << query(1, 1, n, b, c) << '\n';
    }
    return 0;
}