알고리즘 문제를 풀다보면 특정 구간에 합이나 최대값, 최소값 등 정보를 얻어야 할 때가 있다.


데이터의 숫자가 얼마 안되면 배열에 값을 넣어두고 특정 구간을 차례로 탐색하며 값을 연산하여 구하면된다.


아래 예를 보자.8개의 숫자 중 특정 구간의 합을 구하는 문제이다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#include<stdio.h>
 
int data[8= { 1,22,4,8,11,23,6,14 }; //데이터
 
int segSum(int start, int end) { //start부터 end까지의 합
    int sum = 0;
    for (int i = start-1; i < end; i++) {
        sum += data[i];        
    }
    return sum;
}
 
int main() {
    //부분합 구하기
    printf("%d\n", segSum(18));
    printf("%d\n", segSum(88));
    printf("%d\n", segSum(13));
    printf("%d\n", segSum(35));
}
 
cs


데이터의 갯수가  적고 조회하는 횟수도 적다면 위와같은 코드는 문제가 될 것 이 없다.

하지만 데이터의 갯수가 100,000 개이고 부분합읠 조회하는 횟수가 100,000 이라고 하면 최악의 경우 

for문의 반복 횟수가 100,000 * 100,000 이란 어마어마한 숫자가 된다. 


그래서 이를 해결하기 위해 세그먼트 트리(Segment Tree)라는 바이너리 트리를 이용한다.


leaf 노드(가장 끝단에 있는 노드들)에 데이터를 저장하고 부모노드는 자식노드의 합 또는 최대값, 최소값 등을 저장한다.



위의 데이터를 세그먼트 트리로 만들면 아래와 같다.

각 노드에는 leaf노드들의 시작과 끝 번호를 가지고 있어서 해당 구간의 데이터를 요청하면 lead노드까지 탐색해보지 않고 값을 반환 하기 때문에 탐색시간이 빨라진다.


아래 코드는 위의 배열을 탐색하는 코드를 세그먼트 트리를 이용하여 구현한 것이다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include<stdio.h>
 
int data[8= { 1,22,4,8,11,23,6,14};
int index = 0;
 
struct SegTree {
    SegTree *left;    //왼쪽 노드
    SegTree *right; //오른쪽 노드
    int s, e, m;    //시작, 끝, 중간
    int sum;        //자식들의 합
    SegTree() {        //빈생성자
        s = e = m = sum = 0;
        left = right = NULL;
    }
    SegTree(int start, int end) { //생성자
        s = start;
        e = end;
        m = (s + e) / 2//중간값 계산
        if (start == end) { //leaf노드
            sum = data[index++]; //데이터 저장
            return//leaf노드라 더이상 쪼갤필요없어서 종료
        }
        //쪽개기
        left = new SegTree(s, m); //왼쪽 노드
        right = new SegTree(m + 1, e); //오른쪽 노드
        sum = left->sum + right->sum; //자식들의 합
    }
    int serch(int start, int end) {        //start부터 end까지의 합 계산
        if (start > e || s > end)    return 0;    //범위 초과
        if (start <= s && e <= end) return sum; //범위 안에 있음
        return right->serch(start, end) + left->serch(start, end);
    }    
};
 
int main() {
    SegTree *tree = new SegTree(18); //세그먼트 트리 생성
    printf("%d\n", tree->serch(88));
    printf("%d\n", tree->serch(13));
    printf("%d\n", tree->serch(18));
    printf("%d\n", tree->serch(35));
}
 
cs


3에서 5의 합을 구하는 것을 아래 그림을 참고하도록 하자.

구하려는 3~5의 구간을 가지고 노드가 가지고 있는 s와 e 값을 가지고 비교하며 값을 구한다.



Posted by 꿈만은공돌
,