PS

BOJ 28327 : 지그재그

lsj_haru 2024. 2. 20. 19:29

 

 

알고리즘

(segtree)

시간복잡도

$O(NlogN)$

풀이

$g(K)$, 즉 배열 중 $1\sim K$만을 고려하는 경우를 생각해보자.

 

$f(K, y, z)$가 가질 수 있는 최댓값은 구간 $[l, r]$ 중 $K$ 이하의 원소의 개수이다.
이를 각 원소 관점에서 생각해본다면, 가능한 $g(K)$의 최댓값은
모든 $K$ 이하의 자연수 $i$에 대해 $i$의 값을 가지는 위치를 $p(i)$라 할 때
$p(i) \times (N-p(i)+1)$의 합임을 알 수 있다.

 

$a_1, a_2, ..., a_N$에서 $K$ 이하의 원소만을 고른 부분수열을 $b_1, b_2, ..., b_K$라 하면,
$f(K, y, z)$의 실젯값은 위에서 구한 이론적인 값에서
$b_i < b_{i+1} < b_{i+2}$ 또는 $b_i > b_{i+1} > b_{i+2}$인 $i$의 개수를 뺀 값이다.
$(1 \leq i \leq K-2, y \leq p(b_i), p(b_{i+2}) \leq z)$
이 역시도 마찬가지로 각 연속한 세 원소의 관점에서
$p(b_i) \times (N-p(b_{i+2})+1)$의 합임을 알 수 있다.

 

이제 $g(K-1)$의 값을 기반으로 $g(K)$의 값을 구해보자.

  1. $K$의 위치 확인 : $O(1)$ - $(p(K))$
  2. 이론적인 값 update : $O(1)$ - $p(K) \times (N-p(K)+1)$ 추가
  3. $K$에 가장 가까운 좌/우 원소의 위치 확인 : $O(logN)$ - $O(logN)$ 안에 진행해야 하고, 계속 update된다는 특징이 있다.
    1. std::set
      매 순간마다 set에 index를 넣어두면,
      $upper$_$bound(p(K))$로 오른쪽은 쉽게 접근할 수 있다.
      왼쪽은 이전에 처리해놓은 좌우 배열을 기반으로 찾으면,
      오른쪽 원소의 ($K$ 추가 전) 왼쪽 원소임을 알 수 있다.
      이 때 $K$의 위치가 가장 오른쪽이라면,
      $upper$_$bound$ 함수가 제대로 작동하지 않으므로
      가장 오른쪽 원소의 위치를 따로 저장하여 예외처리 해야한다.
    2. segment tree
      매 순간마다 $p(K)$번 배열에 $p(K)$를 넣어두면,
      왼쪽 : $max(1, p(K)-1),$ 오른쪽 : $min(p(K)+1, N)$으로
      쉽게 찾을 수 있다.
  4. 좌우 배열 update : $O(1)$
    - 단순히 $K$와 $K$의 좌우 원소의 정보만 update하면 된다.
    이 때 좌우 원소가 없다면 $0/N+1$을 저장한다.
  5. 빼야 할 값 update : $O(1)$
    - $K$의 좌우 원소를 $L, R$이라 하고
    $L$의 왼쪽 원소를 $LL$, $R$의 오른쪽 원소를 $RR$이라 하자.
    $K$ 추가 시 사라지는 쌍은 $(LL, L, R), (L, R, RR)$이고
    새로 생기는 쌍은 $(LL, L, K), (L, K, R), (K, R, RR)$이다.
    이들이 강증가/감소 하는지의 여부에 따라 update해주면 된다.
    (이 때 존재하지 않는 원소가 포함된 쌍은 관리하지 않아야 한다.)

따라서 $g(K-1)$의 값을 기반으로 $g(K)$의 값을 $O(logN)$에 구할 수 있다.
코드에서는 편의상 $g(1), g(2)$는 직접 구했다.

코드

set 코드

#include <bits/stdc++.h>  
using namespace std;  
#define ll long long  

ll N, a[200005], p[200005], best, down, l[200005], r[200005];  
set<ll> b;  

int main()  
{  
    cin.tie(0)->sync_with_stdio(0);  
    cin>>N;  
    for (int i = 1; i <= N; i++) {  
        cin>>a[i]; p[a[i]] = i;  
    }    
    cout<<p[1]*(N-p[1]+1)<<"\n";  
    ll minz = min(p[1], p[2]), maxz = max(p[1], p[2]);  
    cout<<minz*(N-maxz+1)*2 + (minz+N-maxz+1)*(maxz-minz)<<"\n";  
    b.insert(p[1]); b.insert(p[2]);  
    best = p[1]*(N-p[1]+1) + p[2]*(N-p[2]+1);  
    l[minz] = 0; r[minz] = maxz; l[maxz] = minz; r[maxz] = N+1;  
    ll maxR = maxz;  
    for (int asd = 3; asd <= N; asd++) {  
        ll now = p[asd], ri = N+1, rri = N+1, li = 0, lli = 0;
        best += now*(N-now+1); b.insert(now);  
        if (maxR < now) {  
            li = maxR; lli = l[li]; maxR = now;  
        }        
        else {  
            ri = *b.upper_bound(now); rri = r[ri];
            li = l[ri]; lli = l[li];  
        }        
        r[li] = l[ri] = now; r[now] = ri; l[now] = li;
        if (li != 0) lli = l[li]; if (ri != N+1) rri = r[ri];  
        if (li == 0 && (a[now]-a[ri])*(a[ri]-a[rri]) > 0)
            down += now*(N-rri+1);  
        else if (ri == 0 && (a[lli]-a[li])*(a[li]-a[now]) > 0)
            down += lli*(N-now+1);  
        else if (asd == 3 && (a[li]-a[now])*(a[now]-a[ri]) > 0)
            down += li*(N-ri+1);  
        else if (lli == 0) {  
            if ((a[li]-a[ri])*(a[ri]-a[rri]) > 0)
                down -= li*(N-rri+1);  
            if ((a[li]-a[now])*(a[now]-a[ri]) > 0)
                down += li*(N-ri+1);  
            if ((a[now]-a[ri])*(a[ri]-a[rri]) > 0)
                down += now*(N-rri+1);  
        }        
        else if (rri == N+1) {  
            if ((a[lli]-a[li])*(a[li]-a[ri]) > 0)
                down -= lli*(N-ri+1);  
            if ((a[li]-a[now])*(a[now]-a[ri]) > 0)
                down += li*(N-ri+1);  
            if ((a[lli]-a[li])*(a[li]-a[now]) > 0)
                down += lli*(N-now+1);  
        }        
        else {  
            if ((a[li]-a[ri])*(a[ri]-a[rri]) > 0)
                down -= li*(N-rri+1);  
            if ((a[lli]-a[li])*(a[li]-a[ri]) > 0)
                down -= lli*(N-ri+1);  
            if ((a[li]-a[now])*(a[now]-a[ri]) > 0)
                down += li*(N-ri+1);  
            if ((a[now]-a[ri])*(a[ri]-a[rri]) > 0)
                down += now*(N-rri+1);  
            if ((a[lli]-a[li])*(a[li]-a[now]) > 0)
                down += lli*(N-now+1);  
        }        
        cout<<best-down<<"\n";  
    }    
    return 0;  
}

segment tree 코드

#include <bits/stdc++.h>  
using namespace std;  
#define ll long long  

ll N, a[200005], p[200005];  
pair<ll, ll> tree[800005];  

void add(int now, int s, int e, ll p)  
{  
    if (s > e || p < s || e < p) return;  
    tree[now]
    = {max(tree[now].first, p), min(tree[now].second, p)};  
    if (s == e) return;  
    int mid = s+e >> 1;  
    add(now*2, s, mid, p); add(now*2+1, mid+1, e, p);  
}  

pair<ll, ll> Mm(int now, int s, int e, int l, int r)  
{  
    if (s > e || r < s || e < l) return {0, N+1};  
    if (l <= s && e <= r) return tree[now];  
    int mid = s+e >> 1;  
    pair<ll, ll> ls = Mm(now*2, s, mid, l, r);
    pair<ll, ll> rs = Mm(now*2+1, mid+1, e, l, r);  
    return {max(ls.first, rs.first), min(ls.second, rs.second)};  
}  

ll l[200005], r[200005];  
ll best, down;  

int main()  
{  
    cin.tie(0)->sync_with_stdio(0);  
    cin>>N; for (int i = 1; i <= N*4; i++) tree[i] = {0, N+1};  
    for (int i = 1; i <= N; i++) {  
        cin>>a[i]; p[a[i]] = i;  
    }    
    cout<<p[1]*(N-p[1]+1)<<"\n";  
    ll minz = min(p[1], p[2]), maxz = max(p[1], p[2]);  
    cout<<minz*(N-maxz+1)*2 + (minz+N-maxz+1)*(maxz-minz)<<"\n";  
    add(1, 1, N, p[1]); add(1, 1, N, p[2]);  
    best = p[1]*(N-p[1]+1) + p[2]*(N-p[2]+1);  
    l[minz] = 0; r[minz] = maxz; l[maxz] = minz; r[maxz] = N+1;  
    for (int asd = 3; asd <= N; asd++) {  
        ll now = p[asd]; add(1, 1, N, now); best += now*(N-now+1);
        ll li = Mm(1, 1, N, 1, now-1).first;
        ll ri = Mm(1, 1, N, now+1, N).second;
        ll lli = 0, rri = N+1;  
        r[li] = l[ri] = now; r[now] = ri; l[now] = li;
        if (li != 0) lli = l[li]; if (ri != N+1) rri = r[ri];  
        if (li == 0 && (a[now]-a[ri])*(a[ri]-a[rri]) > 0)
            down += now*(N-rri+1);  
        else if (ri == 0 && (a[lli]-a[li])*(a[li]-a[now]) > 0)
            down += lli*(N-now+1);  
        else if (asd == 3 && (a[li]-a[now])*(a[now]-a[ri]) > 0)
            down += li*(N-ri+1);  
        else if (lli == 0) {  
            if ((a[li]-a[ri])*(a[ri]-a[rri]) > 0)
                down -= li*(N-rri+1);  
            if ((a[li]-a[now])*(a[now]-a[ri]) > 0)
                down += li*(N-ri+1);  
            if ((a[now]-a[ri])*(a[ri]-a[rri]) > 0)
                down += now*(N-rri+1);  
        }        
        else if (rri == N+1) {  
            if ((a[lli]-a[li])*(a[li]-a[ri]) > 0)
                down -= lli*(N-ri+1);  
            if ((a[li]-a[now])*(a[now]-a[ri]) > 0)
                down += li*(N-ri+1);  
            if ((a[lli]-a[li])*(a[li]-a[now]) > 0)
                down += lli*(N-now+1);  
        }        
        else {  
            if ((a[li]-a[ri])*(a[ri]-a[rri]) > 0)
                down -= li*(N-rri+1);  
            if ((a[lli]-a[li])*(a[li]-a[ri]) > 0)
                down -= lli*(N-ri+1);  
            if ((a[li]-a[now])*(a[now]-a[ri]) > 0)
                down += li*(N-ri+1);  
            if ((a[now]-a[ri])*(a[ri]-a[rri]) > 0)
                down += now*(N-rri+1);  
            if ((a[lli]-a[li])*(a[li]-a[now]) > 0)
                down += lli*(N-now+1);  
        }        
        cout<<best-down<<"\n";  
    }    
    return 0;  
}

'PS' 카테고리의 다른 글

BOJ 29335 : Coloring  (2) 2024.03.04
BOJ 28326 : 고기 파티  (0) 2024.02.23
BOJ 12011 : Splitting the Field  (0) 2024.02.11
BOJ 11993 : Circular Barn (Gold)  (0) 2024.02.08
BOJ 11982 : Angry Cows (Gold)  (1) 2024.02.08