遅延セグツリーの仕組みを知りたい

初めましてuetaです。Atcodercodeforces競技プログラミングやっています。

今回はアルゴリズム、遅延セグツリーを勉強しました。

これの理解のために時間を費やしたのでそれをまとめていこうと思います。

セグツリーとは?

任意の範囲の最大値最小値、区間和などを得るアルゴリズム

前処理を配列長Nに対してO(N)で行うと

各クエリの取得、1点更新が O(logN)で実行できる

学ぶなら algo-logic.info

遅延セグツリーとは

遅延セグツリーはセグツリーにおいて、区間更新が可能になっているデータ構造。

遅延セグツリーのアルゴリズムを理解するにはセグメントツリーの理解が必要になる。

詳しくない方は勉強してから読んでいただけると幸いです。

本編

今回はACL(atcoder library)を用いて示しながら示していこうと思う。 今回は区間最小値区間更新とする。以下では配列長をNとする。

データの構造

基本的には通常のセグメントツリーと同じ構造

そこに lz という 遅延を保持する部分を別に持つことでうまく処理をこなす

遅延セグツリーの重要なコマンドは以下の二つ

  • 区間更新 apply(int l,int r,int x)

  • 区間クエリ実行 prod(int l,int r)

区間更新apply(int l,int r,int x)

[ l , r )の範囲をxに更新する 計算量 O(logN)

No.1: lからrの範囲をトップダウンにこれまでの遅延記憶 lz を更新

一番上から、たまった遅延部分lzを更新を伝播させていく。 lzの値があるときにのみ、下に更新する必要がある。その時に、更新した値は初期値に戻す。

この更新数は 2*logN個以下であるため O(logN)

No. 2: lからrの範囲をseg区間をxで更新

lからrをxにする。通常のセグメントツリーと同様に区間を更新する。その時にlzの値も更新する。

これはもちろん logN個以下 O(logN)

No. 3: lからrの範囲をボトムアップに更新

下から、更新を伝播させていく。 セグメントツリーの一点更新と同じ要領で区間を更新する。

この更新数は 2*logN個以下であるため O(logN)


遅延セグツリーapplyコード(区間更新の部分をACLから抜き出したのもの)

    void apply(int l, int r, F f) {
        assert(0 <= l && l <= r && r <= _n);
        if (l == r) return;

        l += size;
        r += size;

// No. 1
        for (int i = log; i >= 1; i--) {
            if (((l >> i) << i) != l) push(l >> i);
            if (((r >> i) << i) != r) push((r - 1) >> i);
        }
// No. 2
        {
            int l2 = l, r2 = r;
            while (l < r) {
                if (l & 1) all_apply(l++, f);
                if (r & 1) all_apply(--r, f);
                l >>= 1;
                r >>= 1;
            }
            l = l2;
            r = r2;
        }
// No. 3
        for (int i = 1; i <= log; i++) {
            if (((l >> i) << i) != l) update(l >> i);
            if (((r >> i) << i) != r) update((r - 1) >> i);
        }
    }

gifにするとこうなる。

1-7を5に更新する

apply実行1


続けて5-12を6に更新する

apply実行2

区間取得 prod(int l,int r)

[ l , r )の範囲を更新する 計算量 O(logN)

No.1: lからrの範囲をトップダウンにこれまでの遅延記憶 lz を更新

一番上から、たまった遅延部分lzを更新を伝播させていく。 lzの値があるときにのみ、下に更新する必要がある。その時に、更新した値は初期値に戻す。

この更新数は 2*logN個以下であるため O(logN)

No. 2: lからrの範囲をseg区間を取得

通常のセグメントツリーと同様に区間を取得する。

これはもちろん logN個以下 O(logN)


遅延セグツリーprodコード(区間取得の部分をACLから抜き出したのもの)

    S prod(int l, int r) {
        assert(0 <= l && l <= r && r <= _n);
        if (l == r) return e();

        l += size;
        r += size;
// No. 1
        for (int i = log; i >= 1; i--) {
            if (((l >> i) << i) != l) push(l >> i);
            if (((r >> i) << i) != r) push(r >> i);
        }
// No. 2
        S sml = e(), smr = e();
        while (l < r) {
            if (l & 1) sml = op(sml, d[l++]);
            if (r & 1) smr = op(d[--r], smr);
            l >>= 1;
            r >>= 1;
        }

        return op(sml, smr);
    }

gifで確認。

続けて2-7を取得する

prod実行1


続けて11-15を取得する

prod実行2

最後に

遅延の流れをようやく理解できた。 コードのすべては以下に示す。

ぜひ参考にしてください。

遅延セグツリーに関わるACLのコード

// https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DSL_2_F&lang=ja
#include <bits/stdc++.h>
#include <cstdint>
using namespace std;

//https://github.com/atcoder/ac-library
int ceil_pow2(int n) {
    int x = 0;
    while ((1U << x) < (unsigned int)(n)) x++;
    return x;
}
template <class S,
          S (*op)(S, S),
          S (*e)(),
          class F,
          S (*mapping)(F, S),
          F (*composition)(F, F),
          F (*id)()>
struct lazy_segtree {
  public:
    lazy_segtree() : lazy_segtree(0) {}
    lazy_segtree(int n) : lazy_segtree(vector<S>(n, e())) {}
    lazy_segtree(const vector<S>& v) : _n(int(v.size())) {
        log = ceil_pow2(_n);
        size = 1 << log;
        d = vector<S>(2 * size, e());
        lz = vector<F>(size, id());
        for (int i = 0; i < _n; i++) d[size + i] = v[i];
        for (int i = size - 1; i >= 1; i--) {
            update(i);
        }
    }

    void set(int p, S x) {
        assert(0 <= p && p < _n);
        p += size;
        for (int i = log; i >= 1; i--) push(p >> i);
        d[p] = x;
        for (int i = 1; i <= log; i++) update(p >> i);
    }

    S get(int p) {
        assert(0 <= p && p < _n);
        p += size;
        for (int i = log; i >= 1; i--) push(p >> i);
        return d[p];
    }

    S prod(int l, int r) {
        assert(0 <= l && l <= r && r <= _n);
        if (l == r) return e();

        l += size;
        r += size;

        for (int i = log; i >= 1; i--) {
            if (((l >> i) << i) != l) push(l >> i);
            if (((r >> i) << i) != r) push(r >> i);
        }

        S sml = e(), smr = e();
        while (l < r) {
            if (l & 1) sml = op(sml, d[l++]);
            if (r & 1) smr = op(d[--r], smr);
            l >>= 1;
            r >>= 1;
        }

        return op(sml, smr);
    }

    S all_prod() { return d[1]; }

    void apply(int p, F f) {
        assert(0 <= p && p < _n);
        p += size;
        for (int i = log; i >= 1; i--) push(p >> i);
        d[p] = mapping(f, d[p]);
        for (int i = 1; i <= log; i++) update(p >> i);
    }
    void apply(int l, int r, F f) {
        assert(0 <= l && l <= r && r <= _n);
        if (l == r) return;

        l += size;
        r += size;

        for (int i = log; i >= 1; i--) {
            if (((l >> i) << i) != l) push(l >> i);
            if (((r >> i) << i) != r) push((r - 1) >> i);
        }

        {
            int l2 = l, r2 = r;
            while (l < r) {
                if (l & 1) all_apply(l++, f);
                if (r & 1) all_apply(--r, f);
                l >>= 1;
                r >>= 1;
            }
            l = l2;
            r = r2;
        }

        for (int i = 1; i <= log; i++) {
            if (((l >> i) << i) != l) update(l >> i);
            if (((r >> i) << i) != r) update((r - 1) >> i);
        }
    }

    template <bool (*g)(S)> int max_right(int l) {
        return max_right(l, [](S x) { return g(x); });
    }
    template <class G> int max_right(int l, G g) {
        assert(0 <= l && l <= _n);
        assert(g(e()));
        if (l == _n) return _n;
        l += size;
        for (int i = log; i >= 1; i--) push(l >> i);
        S sm = e();
        do {
            while (l % 2 == 0) l >>= 1;
            if (!g(op(sm, d[l]))) {
                while (l < size) {
                    push(l);
                    l = (2 * l);
                    if (g(op(sm, d[l]))) {
                        sm = op(sm, d[l]);
                        l++;
                    }
                }
                return l - size;
            }
            sm = op(sm, d[l]);
            l++;
        } while ((l & -l) != l);
        return _n;
    }

    template <bool (*g)(S)> int min_left(int r) {
        return min_left(r, [](S x) { return g(x); });
    }
    template <class G> int min_left(int r, G g) {
        assert(0 <= r && r <= _n);
        assert(g(e()));
        if (r == 0) return 0;
        r += size;
        for (int i = log; i >= 1; i--) push((r - 1) >> i);
        S sm = e();
        do {
            r--;
            while (r > 1 && (r % 2)) r >>= 1;
            if (!g(op(d[r], sm))) {
                while (r < size) {
                    push(r);
                    r = (2 * r + 1);
                    if (g(op(d[r], sm))) {
                        sm = op(d[r], sm);
                        r--;
                    }
                }
                return r + 1 - size;
            }
            sm = op(d[r], sm);
        } while ((r & -r) != r);
        return 0;
    }

  private:
    int _n, size, log;
    vector<S> d;
    vector<F> lz;

    void update(int k) { d[k] = op(d[2 * k], d[2 * k + 1]); }
    void all_apply(int k, F f) {
        d[k] = mapping(f, d[k]);
        if (k < size) lz[k] = composition(f, lz[k]);
    }
    void push(int k) {
        all_apply(2 * k, lz[k]);
        all_apply(2 * k + 1, lz[k]);
        lz[k] = id();
    }
};

using S = long long;
using F = long long;
const S INF = (1LL<<31) -1 ;
const F ID = (1LL<<31) -1;
S op(S a, S b){ return std::min(a, b); }
S e(){ return INF; }
S mapping(F f, S x){ return (f == ID ? x : f); }
F composition(F f, F g){ return (f == ID ? g : f); }
F id(){ return ID; }


int main(){
    int n,q;
    cin >> n >> q;
    
    lazy_segtree<S, op, e, F, mapping, composition, id> seg(n);
    while (q--)
    {
        int r,s,t,x;
        cin >> r;
        if(r==0){
            cin >> s >> t >> x;
            seg.apply(s,t+1,F{x});
        }else{
            cin >> s >> t;
            cout << seg.prod(s,t+1) << endl;
        }
    }
    
    return 0;
}

参考

非再帰版の遅延評価セグメント木の実装メモ - 日々drdrする人のメモ

遅延評価セグメント木も完全に理解する #競技プログラミング - Qiita

GitHub - atcoder/ac-library: AtCoder Library