정의
세그먼트 트리란 구간의 값을 전처리하여 특정 구간의 쿼리를 빠르게 수행하는 자료구조이다.
- 빈번한 구간 연산이 일어날 때 유용하다.
- 구간의 합, 곱, 최솟값, 최댓값, 최대공약수 등을 빠르게 구할 때 사용한다.
- 세그먼트 트리와 인덱스 트리를 이용하면 O(lgN)으로 구간의 값을 구할 수 있다.
세그먼트 트리 | 인덱스 트리 | |
접근 방식 | Top-down | Bottom-up |
구조 | 포화 이진트리 (2^n 일때 완전 이진트리) |
완전 이진트리 |
특징 | 범용성이 높아 응용하기 좋음 | 구현 난이도가 쉽고 직관적임 |
세그먼트 트리 예시 코드
// 백준 2042번 구간 합 구하기
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
int si = 1;
ll arr[1024*1024];
ll seg[1024*1024*2];
// 세그먼트 트리 초기화
long long init(int l, int r, int node){
if(l == r) return seg[node] = arr[l];
int mid = (l+r)/2;
return seg[node] = init(l, mid, node*2) + init(mid+1, r, node*2+1);
}
// 값 변경
void update(int l, int r, int node, int idx, ll n) {
if(idx < l || r < idx) return;
seg[node] += n;
if(l != r) {
int mid = (l+r)/2;
update(l, mid ,node*2, idx, n);
update(mid+1, r, node*2+1, idx, n);
}
}
// 구간의 합 구하기
ll sum(int l, int r, int node, int st, int en) {
if(en < l || r < st) return 0;
if(st <= l && r <= en) return seg[node];
int mid = (l+r)/2;
return sum(l, mid, node*2, st, en) + sum(mid+1, r, node*2+1, st, en);
}
int main() {
int n, m, k;
cin >> n >> m >> k;
while(si < n) si <<= 1;
for(int i = 0; i < n; i++)
cin >> arr[i];
init(0, n-1, 1);
ll a,b,c;
for(int i = 0; i < m+k; i++) {
cin >> a >> b >> c;
if(a == 1) {
update(0, n-1, 1, b-1, c-arr[b-1]);
arr[b-1] = c;
}
else
cout << sum(0, n-1, 1, b-1, c-1) << '\n';
}
return 0;
}
인덱스 트리 예시 코드
// 백준 2042번 구간 합 구하기
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
int si = 1;
ll seg[1024*1024*2];
// 값 변경
void update(int i, ll n) {
i += si;
seg[i] = n;
while(i > 1) {
i /= 2;
seg[i] = seg[i*2] + seg[i*2+1];
}
}
// 구간의 합 구하기
ll sum(int a, int b) {
a += si;
b += si;
ll ret=0;
while(a <= b) {
if(a & 1) ret += seg[a];
if(!(b & 1)) ret += seg[b];
a = (a + 1) / 2;
b = (b - 1) / 2;
}
return ret;
}
int main() {
int n, m, k;
cin >> n >> m >> k;
while(si < n)
si <<= 1;
for(int i = si; i < si+n; i++)
cin >> seg[i];
// 인덱스 트리 초기화
for(int i = si-1; i >= 1; i--)
seg[i] = seg[i*2] + seg[i*2+1];
ll a, b, c;
for(int i = 0; i < m+k; i++) {
cin >> a >> b >> c;
if(a == 1) update(b-1, c);
if(a == 2) cout << sum(b-1, c-1) << '\n';
}
return 0;
}