遅延セグツリーの仕組みを知りたい
初めましてuetaです。Atcoder やcodeforcesで競技プログラミングやっています。
今回はアルゴリズム、遅延セグツリーを勉強しました。
これの理解のために時間を費やしたのでそれをまとめていこうと思います。
セグツリーとは?
前処理を配列長Nに対してO(N)で行うと
各クエリの取得、1点更新が O(logN)で実行できる
学ぶなら algo-logic.info
遅延セグツリーとは
遅延セグツリーはセグツリーにおいて、区間更新が可能になっているデータ構造。
遅延セグツリーのアルゴリズムを理解するにはセグメントツリーの理解が必要になる。
詳しくない方は勉強してから読んでいただけると幸いです。
本編
今回はACL(atcoder library)を用いて示しながら示していこうと思う。 今回は区間最小値、区間更新とする。以下では配列長をNとする。
データの構造
基本的には通常のセグメントツリーと同じ構造
そこに lz という 遅延を保持する部分を別に持つことでうまく処理をこなす
遅延セグツリーの重要なコマンドは以下の二つ
区間更新apply(int l,int r,int x)
[ l , r )の範囲をxに更新する 計算量 O(logN)
No.1: lからrの範囲をトップダウンにこれまでの遅延記憶 lz を更新
一番上から、たまった遅延部分lzを更新を伝播させていく。 lzの値があるときにのみ、下に更新する必要がある。その時に、更新した値は初期値に戻す。
この更新数は 2*logN個以下であるため O(logN)
No. 2: lからrの範囲をseg区間をxで更新
lからrをxにする。通常のセグメントツリーと同様に区間を更新する。その時にlzの値も更新する。
これはもちろん logN個以下 O(logN)
No. 3: lからrの範囲をボトムアップに更新
下から、更新を伝播させていく。 セグメントツリーの一点更新と同じ要領で区間を更新する。
この更新数は 2*logN個以下であるため O(logN)
遅延セグツリーapplyコード(区間更新の部分をACLから抜き出したのもの)
void apply(int l, int r, F f) { assert(0 <= l && l <= r && r <= _n); if (l == r) return; l += size; r += size; // No. 1 for (int i = log; i >= 1; i--) { if (((l >> i) << i) != l) push(l >> i); if (((r >> i) << i) != r) push((r - 1) >> i); } // No. 2 { int l2 = l, r2 = r; while (l < r) { if (l & 1) all_apply(l++, f); if (r & 1) all_apply(--r, f); l >>= 1; r >>= 1; } l = l2; r = r2; } // No. 3 for (int i = 1; i <= log; i++) { if (((l >> i) << i) != l) update(l >> i); if (((r >> i) << i) != r) update((r - 1) >> i); } }
gifにするとこうなる。
1-7を5に更新する
続けて5-12を6に更新する
区間取得 prod(int l,int r)
[ l , r )の範囲を更新する 計算量 O(logN)
No.1: lからrの範囲をトップダウンにこれまでの遅延記憶 lz を更新
一番上から、たまった遅延部分lzを更新を伝播させていく。 lzの値があるときにのみ、下に更新する必要がある。その時に、更新した値は初期値に戻す。
この更新数は 2*logN個以下であるため O(logN)
No. 2: lからrの範囲をseg区間を取得
通常のセグメントツリーと同様に区間を取得する。
これはもちろん logN個以下 O(logN)
遅延セグツリーprodコード(区間取得の部分をACLから抜き出したのもの)
S prod(int l, int r) { assert(0 <= l && l <= r && r <= _n); if (l == r) return e(); l += size; r += size; // No. 1 for (int i = log; i >= 1; i--) { if (((l >> i) << i) != l) push(l >> i); if (((r >> i) << i) != r) push(r >> i); } // No. 2 S sml = e(), smr = e(); while (l < r) { if (l & 1) sml = op(sml, d[l++]); if (r & 1) smr = op(d[--r], smr); l >>= 1; r >>= 1; } return op(sml, smr); }
gifで確認。
続けて2-7を取得する
続けて11-15を取得する
最後に
遅延の流れをようやく理解できた。 コードのすべては以下に示す。
ぜひ参考にしてください。
遅延セグツリーに関わるACLのコード
// https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DSL_2_F&lang=ja #include <bits/stdc++.h> #include <cstdint> using namespace std; //https://github.com/atcoder/ac-library int ceil_pow2(int n) { int x = 0; while ((1U << x) < (unsigned int)(n)) x++; return x; } template <class S, S (*op)(S, S), S (*e)(), class F, S (*mapping)(F, S), F (*composition)(F, F), F (*id)()> struct lazy_segtree { public: lazy_segtree() : lazy_segtree(0) {} lazy_segtree(int n) : lazy_segtree(vector<S>(n, e())) {} lazy_segtree(const vector<S>& v) : _n(int(v.size())) { log = ceil_pow2(_n); size = 1 << log; d = vector<S>(2 * size, e()); lz = vector<F>(size, id()); for (int i = 0; i < _n; i++) d[size + i] = v[i]; for (int i = size - 1; i >= 1; i--) { update(i); } } void set(int p, S x) { assert(0 <= p && p < _n); p += size; for (int i = log; i >= 1; i--) push(p >> i); d[p] = x; for (int i = 1; i <= log; i++) update(p >> i); } S get(int p) { assert(0 <= p && p < _n); p += size; for (int i = log; i >= 1; i--) push(p >> i); return d[p]; } S prod(int l, int r) { assert(0 <= l && l <= r && r <= _n); if (l == r) return e(); l += size; r += size; for (int i = log; i >= 1; i--) { if (((l >> i) << i) != l) push(l >> i); if (((r >> i) << i) != r) push(r >> i); } S sml = e(), smr = e(); while (l < r) { if (l & 1) sml = op(sml, d[l++]); if (r & 1) smr = op(d[--r], smr); l >>= 1; r >>= 1; } return op(sml, smr); } S all_prod() { return d[1]; } void apply(int p, F f) { assert(0 <= p && p < _n); p += size; for (int i = log; i >= 1; i--) push(p >> i); d[p] = mapping(f, d[p]); for (int i = 1; i <= log; i++) update(p >> i); } void apply(int l, int r, F f) { assert(0 <= l && l <= r && r <= _n); if (l == r) return; l += size; r += size; for (int i = log; i >= 1; i--) { if (((l >> i) << i) != l) push(l >> i); if (((r >> i) << i) != r) push((r - 1) >> i); } { int l2 = l, r2 = r; while (l < r) { if (l & 1) all_apply(l++, f); if (r & 1) all_apply(--r, f); l >>= 1; r >>= 1; } l = l2; r = r2; } for (int i = 1; i <= log; i++) { if (((l >> i) << i) != l) update(l >> i); if (((r >> i) << i) != r) update((r - 1) >> i); } } template <bool (*g)(S)> int max_right(int l) { return max_right(l, [](S x) { return g(x); }); } template <class G> int max_right(int l, G g) { assert(0 <= l && l <= _n); assert(g(e())); if (l == _n) return _n; l += size; for (int i = log; i >= 1; i--) push(l >> i); S sm = e(); do { while (l % 2 == 0) l >>= 1; if (!g(op(sm, d[l]))) { while (l < size) { push(l); l = (2 * l); if (g(op(sm, d[l]))) { sm = op(sm, d[l]); l++; } } return l - size; } sm = op(sm, d[l]); l++; } while ((l & -l) != l); return _n; } template <bool (*g)(S)> int min_left(int r) { return min_left(r, [](S x) { return g(x); }); } template <class G> int min_left(int r, G g) { assert(0 <= r && r <= _n); assert(g(e())); if (r == 0) return 0; r += size; for (int i = log; i >= 1; i--) push((r - 1) >> i); S sm = e(); do { r--; while (r > 1 && (r % 2)) r >>= 1; if (!g(op(d[r], sm))) { while (r < size) { push(r); r = (2 * r + 1); if (g(op(d[r], sm))) { sm = op(d[r], sm); r--; } } return r + 1 - size; } sm = op(d[r], sm); } while ((r & -r) != r); return 0; } private: int _n, size, log; vector<S> d; vector<F> lz; void update(int k) { d[k] = op(d[2 * k], d[2 * k + 1]); } void all_apply(int k, F f) { d[k] = mapping(f, d[k]); if (k < size) lz[k] = composition(f, lz[k]); } void push(int k) { all_apply(2 * k, lz[k]); all_apply(2 * k + 1, lz[k]); lz[k] = id(); } }; using S = long long; using F = long long; const S INF = (1LL<<31) -1 ; const F ID = (1LL<<31) -1; S op(S a, S b){ return std::min(a, b); } S e(){ return INF; } S mapping(F f, S x){ return (f == ID ? x : f); } F composition(F f, F g){ return (f == ID ? g : f); } F id(){ return ID; } int main(){ int n,q; cin >> n >> q; lazy_segtree<S, op, e, F, mapping, composition, id> seg(n); while (q--) { int r,s,t,x; cin >> r; if(r==0){ cin >> s >> t >> x; seg.apply(s,t+1,F{x}); }else{ cin >> s >> t; cout << seg.prod(s,t+1) << endl; } } return 0; }
参考
非再帰版の遅延評価セグメント木の実装メモ - 日々drdrする人のメモ