segment tree가 뭔가요?
segment tree(이하 segtree)는 주어진 배열에서
- 특정 원소 수정
- 구간 단위 연산의 결과(ex: 구간 합, 구간 최댓값)
를 $O(logN)$에 수행하게 해주는 알고리즘입니다.
기본 원리
작동하는 기본 원리는 이분 탐색과 비슷한 듯 다릅니다.
위의 그림에서와 같이, 현재 구간에 대해 반으로 나누며
필요한 경우 새로운 구간에 대해 재귀적으로 반복하는 것입니다.
각 범위별로 구간 단위 연산의 결과를 저장해놓는다 하면,
배열의 크기가 $N$일 때 $4N$ 이하의 크기의 배열에 저장할 수 있습니다.
저장 원리는 현재 범위를 저장할 번호가 $k$라 할 때,
앞쪽 절반 배열은 $2k$, 뒤쪽 절반 배열은 $2k+1$에 저장하는 것입니다.
예시로 배열의 크기가 8이라 하면, 연산 결과를 $tree$에 저장할 때
$(1, 8)$은 $tree[1]$에, 반으로 나눈 $(1, 4)$, $(5, 8)$은 각각 $tree[2], tree[3]$에
저장하는 방식으로 저장합니다.
특정 원소 수정
편의상 segtree의 구간 단위 연산이 구간 합이라 하겠습니다.
배열에서 특정 원소 $a[p]$에 $x$를 더한다고 한다면,
- $p$가 현재 확인중인 배열의 범위 $(s, e)$에 속해있지 않다면 return,
- $s=e=p$이면 $tree$ 값에 $x$만 더하고 return,
- 이외의 경우 반으로 나눈 두 배열에 대해 다시 재귀적으로 진행합니다. 그 후 두 반으로 나눈 배열의 구간 합을 더해주면 현재 배열의 구간 합이 되겠죠?
코드는 다음과 같습니다.
void edit(int now, int s, int e, int p, int x)
//now: tree 배열에 저장할 번호, s, e: 현재 관리할 배열의 시작/끝 값
//p: 값을 바꿀 원소의 위치, x: 값을 바꿀 원소에 더할 값
{
if (s > e || p < s || e < p) return; //범위가 안맞으면 return
if (s == e) {
tree[now] += x; return;
} //s=e=p이면 x 더하고 return
int mid = (s+e)/2;
edit(now*2, s, mid, p, x);
edit(now*2+1, mid+1, e, p, x); //반반 나눠서 재귀적으로 확인
tree[now] = tree[now*2] + tree[now*2+1];
//반으로 나눈 두 배열의 부분 합의 합
}
구간 단위 연산의 결과
계속 구간 합으로 설명하겠습니다.
현재 확인중인 배열의 범위를 $(s, e)$, 구간 합을 구할 범위를 $(f, r)$이라 하면,
- $(s, e)$와 $(f, r)$이 겹치지 않으면 return 0,
- $(s, e) \subset (f, r)$이면 return $tree[now]$,
- 이외의 경우(어중간하게 겹친 경우) 반으로 나눈 두 배열에 대해 재귀적으로 진행합니다. 그 후 반으로 나눈 배열의 return 결과의 합을 return합니다.
코드는 다음과 같습니다.
int sum(int now, int s, int e, int f, int r)
//now: tree 배열에 저장할 번호, s, e: 현재 관리할 배열의 시작/끝 값
//f, r: 구간 합을 구할 범위의 시작/끝 값
{
if (s > e || r < s || e < f) return 0; //안 겹치면 return 0
if (f <= s && e <= r) return tree[now];
//(s, e)가 (f, r)에 완전히 속하면 현재 tree값 return
int mid = (s+e)/2;
return sum(now*2, s, mid, f, r) + sum(now*2+1, mid+1, e, f, r);
//반으로 나눈 두 배열의 return값의 합 return
}
초기 segtree 만들기
주어진 배열에 대해 segtree를 만들고 시작해야 한다면,
앞에서 설명한 특정 원소 수정을 $N$번 진행하여 만들어도 됩니다.
이 경우 시간복잡도는 $O(NlogN)$입니다.
그러나 비슷한 방식으로 $O(N)$에 구현할 수도 있습니다.
(원리는 비슷하므로 코드를 보며 스스로 생각해보시길 바랍니다.)
void mt(int now, int s, int e)
{
if (s > e) return;
if (s == e) {
tree[now] = a[s];
return;
}
int mid = (s+e)/2;
mt(now*2, s, mid); mt(now*2+1, mid, e);
tree[now] = tree[now*2] + tree[now*2+1];
}
개선
연산이 평균, 2배 또는 2배 후 +1 뿐이므로 비트연산자로 구현 가능합니다.
(s+e)/2 == s+e>>1
now*2 == now<<1
now*2+1 == now<<1|1
생각해보면 $s>e$인 경우 return할 필요는 없습니다.
해당 경우가 생기는 경우는 $mid+1>e$인 경우가 유일한데,
이미 $s=e$인 경우 return하므로 해당 경우는 존재하지 않습니다.
문제 적용
2042번 문제의 예시 코드입니다.
구간 합 구하기
#include <bits/stdc++.h>
using namespace std;
#define ll long long
int N, M, K;
ll a[1000005], tree[4000005];
void edit(int now, int s, int e, int p, ll x)
{
if (p < s || e < p) return;
if (s == e) {tree[now] = x; return;}
int mid = s+e >> 1;
edit(now<<1, s, mid, p, x);
edit(now<<1|1, mid+1, e, p, x);
tree[now] = tree[now<<1] + tree[now<<1|1];
}
ll sum(int now, int s, int e, int f, int r)
{
if (r < s || e < f) return 0;
if (f <= s && e <= r) return tree[now];
int mid = s+e >> 1;
return sum(now<<1, s, mid, f, r)
+ sum(now<<1|1, mid+1, e, f, r);
}
void mt(int now, int s, int e)
{
if (s == e) {tree[now] = a[s]; return;}
int mid = s+e >> 1;
mt(now<<1, s, mid);
mt(now<<1|1, mid+1, e);
tree[now] = tree[now<<1] + tree[now<<1|1];
}
int main()
{
cin.tie(0)->sync_with_stdio(0);
cin>>N>>M>>K;
for (int i = 1; i <= N; i++) cin>>a[i];
mt(1, 1, N);
for (int i = 1; i <= M+K; i++) {
ll t, x, y; cin>>t>>x>>y;
if (t&1) edit(1, 1, N, x, y);
else cout<<sum(1, 1, N, x, y)<<"\n";
} return 0;
}
시간복잡도는 $O(N+(M+K)logN)$입니다.
추가적인 스킬들은 추후 포스팅하겠습니다.
'Algorithm' 카테고리의 다른 글
Lazy Segment Tree (1) | 2024.02.10 |
---|