树链剖分(轻重链剖分)原理与实现

树链剖分(轻重链剖分)原理与实现

概述

树链剖分是一种将树上的链划分为轻重链,从而实现降低复杂度的处理方式。剖分后的树,单次查询LCA的复杂度为O(\log n),与Tarjan-LCA算法与倍增算法复杂度相同,但是树剖可以在线,非常好用。

原理与实现

轻重链

我们把一个有根树一个点子树最大的儿子叫做这个点的重儿子,重儿子连接成的链为重链,而其他的链叫做轻链。每个点都有重儿子,因此重链可能并不连续。但是,可以证明一个点到根的路径上的链数为O(\log n)级别的。
下面是一棵剖分好了的树,图片来自[知识点]树链剖分 – jinkun113 – 博客园,感谢原作者。
180315b 1 - 树链剖分(轻重链剖分)原理与实现
如果我们每次从查询LCA的两个点上跳到他们所在的链顶,直到他们的链顶相同,就能够完成查询LCA的工作。
下面是树剖的两个DFS,用于处理子树信息与划分轻重链。

// 处理子树信息
void dfs1(int u) {
    int mx = -1;
    for(int i = head[u]; i != -1; i = gra[i].nxt) {
        int v = gra[i].to;
        if(v == fa[u]) continue;
        dep[v] = dep[u] + 1;
        fa[v] = u;
        siz[u]++;
        dfs1(v);
        siz[u] += siz[v];
        if(siz[v] > mx) {
            mx = siz[v];
            son[u] = v;
        }
    } 
}

// 处理链信息
void dfs2(int u, int tp) {
    top[u] = tp;
    if(son[u] != 0) {
        // 优先遍历重儿子
        dfs2(son[u], tp);
    }
    for(int i = head[u]; i != -1; i = gra[i].nxt) {
        int v = gra[i].to;
        if(v == son[u] || v == fa[u]) continue;
        dfs2(v, v);
    }
}

查询LCA可以使用下面的实现。

int findlca(int u, int v) {
    int f1 = top[u], f2 = top[v];
    while(f1 != f2) {
        if(dep[f1] > dep[f2]) {
            std::swap(f1, f2);
            std::swap(u, v);
        }
        v = fa[f2];
        f2 = top[v];
    }
    if(dep[u] < dep[v]) return u; else return v;
}

DFS序

虽然划分了轻重链,但是只有这些信息我们也只能找个LCA。如果想维护链上信息,我们可以使用DFS序。在dfs2的时候把前序遍历的DFS序处理出来,连续的一段代表一个子树,而由于优先遍历重链,重链也是连续的一段。这样,我们就可以想办法套数据结构来统计链上信息了。
下面是统计DFS序版本的dfs2。

void dfs2(int u, int tp) {
    top[u] = tp;
    pos[u] = stn++;
    if(son[u] != -1) {
        dfs2(son[u], tp);
    }
    for(int e = head[u]; e != -1; e = gra[e].nxt) {
        int v = gra[e].to;
        if(v == son[u] || v == fa[u]) continue;
        dfs2(v, v);
    }
}

例题1:【P3379】【模板】最近公共祖先(LCA) – 洛谷

LCA在线查询,树剖疾如岛风哦~

// Code by KSkun, 2017/11
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm> 

struct io {
    char buf[1 << 26], *s;

    io() {
        fread(s = buf, 1, 1 << 26, stdin);
    }

    inline int read() {
        register int res = 0;
        while(*s < '0' || *s > '9') s++;
        while(*s >= '0' && *s <= '9') res = res * 10 + *s++ - '0';
        return res;
    }
} ip;

#define read ip.read

int n, m, s, x, y, a, b;

struct Edge {
    int to, nxt;
} gra[1000005];
int head[500005], tot = 0;

inline void aedge(int u, int v) {
    gra[tot].nxt = head[u];
    gra[tot].to = v;
    head[u] = tot++;
}

int dep[500005], siz[500005], son[500005], fa[500005], top[500005];

void dfs1(int u) {
    int mx = -1;
    for(int i = head[u]; i != -1; i = gra[i].nxt) {
        int v = gra[i].to;
        if(v == fa[u]) continue;
        dep[v] = dep[u] + 1;
        fa[v] = u;
        siz[u]++;
        dfs1(v);
        siz[u] += siz[v];
        if(siz[v] > mx) {
            mx = siz[v];
            son[u] = v;
        }
    } 
}

void dfs2(int u, int tp) {
    top[u] = tp;
    if(son[u] != 0) {
        dfs2(son[u], tp);
    }
    for(int i = head[u]; i != -1; i = gra[i].nxt) {
        int v = gra[i].to;
        if(v == son[u] || v == fa[u]) continue;
        dfs2(v, v);
    }
}

int findlca(int u, int v) {
    int f1 = top[u], f2 = top[v];
    while(f1 != f2) {
        if(dep[f1] > dep[f2]) {
            std::swap(f1, f2);
            std::swap(u, v);
        }
        v = fa[f2];
        f2 = top[v];
    }
    if(dep[u] < dep[v]) return u; else return v;
}

int main() {
    memset(head, -1, sizeof head);
    n = read(); m = read(); s = read();
    for(int i = 0; i < n - 1; i++) {
        x = read(); y = read();
        aedge(x, y);
        aedge(y, x);
    }
    dfs1(s);
    dfs2(s, s);
    for(int i = 0; i < m; i++) {
        a = read(); b = read();
        printf("%d\n", findlca(a, b));
    }
    return 0;
}

例题2:【P3384】【模板】树链剖分 – 洛谷

线段树维护链信息和子树信息w

// Code by KSkun, 2017/10
#include <cstdio>
#include <cctype> 
#include <cstring>
typedef long long LL;

inline int read() {
    register int res = 0;
    register bool neg = false;
    char c = '*';
    while(!isdigit(c)) {
        if(c == '-') neg = true;
        c = getchar();
    }
    while(isdigit(c)) res *= 10, res += c - '0', c = getchar();
    if(neg) res *= -1;
    return res;
}

int n, m, r, p;

// Seg Tree
#define lch o << 1
#define rch (o << 1) | 1
#define mid ((l + r) >> 1)

LL a[100005], tree[400005], lazy[400005];

inline void build(int o, int l, int r) {
    if(l == r) {
        tree[o] = a[l];
        return;
    }
    build(lch, l, mid);
    build(rch, mid + 1, r);
    tree[o] = tree[lch] + tree[rch];
}

inline void pushdown(int o, int l, int r) {
    if(lazy[o] != 0) {
        lazy[lch] += lazy[o];
        lazy[rch] += lazy[o];
        tree[lch] += lazy[o] * (mid - l + 1);
        tree[rch] += lazy[o] * (r - mid);
        lazy[o] = 0;
    }
}

inline void add(int o, int l, int r, int ll, int rr, int v) {
    if(l >= ll && r <= rr) {
        tree[o] += v * (r - l + 1);
        lazy[o] += v;
        return;
    }
    pushdown(o, l, r);
    if(ll <= mid) add(lch, l, mid, ll, rr, v);
    if(rr > mid) add(rch, mid + 1, r, ll, rr, v);
    tree[o] = tree[lch] + tree[rch];
}

inline LL query(int o, int l, int r, int ll, int rr) {
    if(l >= ll && r <= rr) { 
        return tree[o];
    } 
    pushdown(o, l, r);
    LL sum = 0;
    if(ll <= mid) sum += query(lch, l, mid, ll, rr);
    if(rr > mid) sum += query(rch, mid + 1, r, ll, rr);
    return sum;
} 

// Poufen

struct Edge {
    int to, nxt;
} gra[200005];
int head[100005], tot = 0;

inline void addedge(int u, int v) {
    gra[tot].to = v;
    gra[tot].nxt = head[u];
    head[u] = tot++;
}

int dep[100005], siz[100005], son[100005], pos[100005], 
    fa[100005], top[100005], val[100005], stn = 1;

void dfs1(int u) {
    int e = head[u], mx = -1;
    while(e != -1) {
        int v = gra[e].to;
        if(v == fa[u]) {
            e = gra[e].nxt;
            continue;
        }
        dep[v] = dep[u] + 1;
        fa[v] = u;
        siz[u]++;
        dfs1(v);
        siz[u] += siz[v];
        if(siz[v] > mx) {
            son[u] = v;
            mx = siz[v];
        }
        e = gra[e].nxt;
    } 
}

void dfs2(int u, int tp) {
    top[u] = tp;
    pos[u] = stn++;
    if(son[u] != -1) {
        dfs2(son[u], tp);
    }
    int e = head[u];
    while(e != -1) {
        int v = gra[e].to;
        if(v == son[u] || v == fa[u]) {
            e = gra[e].nxt;
            continue;
        }
        dfs2(v, v);
        e = gra[e].nxt;
    }
}

inline void addpath(int x, int y, int val) {
    int fx = top[x], fy = top[y];
    while(fx != fy) {
        if(dep[fx] >= dep[fy]) {
            add(1, 1, n, pos[fx], pos[x], val);
            x = fa[fx];
        } else {
            add(1, 1, n, pos[fy], pos[y], val);
            y = fa[fy];
        }
        fx = top[x], fy = top[y];
    }
    if(pos[x] <= pos[y] || x == y) {
        add(1, 1, n, pos[x], pos[y], val);
    } else {
        add(1, 1, n, pos[y], pos[x], val);
    }
}

inline LL querypath(int x, int y) {
    LL res = 0;
    int fx = top[x], fy = top[y];
    while(fx != fy) {
        if(dep[fx] >= dep[fy]) {
            res += query(1, 1, n, pos[fx], pos[x]);
            x = fa[fx];
        } else {
            res += query(1, 1, n, pos[fy], pos[y]);
            y = fa[fy];
        }
        fx = top[x], fy = top[y];
    }
    if(pos[x] <= pos[y] || x == y) {
        res += query(1, 1, n, pos[x], pos[y]);
    } else {
        res += query(1, 1, n, pos[y], pos[x]);
    }
    return res;
}

inline void addsubt(int x, int v) {
    add(1, 1, n, pos[x], pos[x] + siz[x], v);
}

inline LL querysubt(int x) {
    return query(1, 1, n, pos[x], pos[x] + siz[x]);
}

int op, x, y, z;

int main() {
    n = read(), m = read(), r = read(), p = read();
    memset(head, -1, sizeof head);
    memset(son, -1, sizeof son);
    memset(fa, -1, sizeof fa);
    for(int i = 1; i <= n; i++) {
        val[i] = read();
    }
    for(int i = 0; i < n - 1; i++) {
        x = read(), y = read();
        addedge(x, y);
        addedge(y, x);
    }
    dep[r] = 0;
    fa[r] = r;
    dfs1(r);
    top[r] = r;
    dfs2(r, r);
    for(int i = 1; i <= n; i++) {
        a[pos[i]] = val[i];
    }
    build(1, 1, n);
    for(int i = 0; i < m; i++) {
        op = read();
        if(op == 1) {
            x = read(), y = read(), z = read();
            addpath(x, y, z);
        } else if(op == 2) {
            x = read(), y = read();
            printf("%lld\n", querypath(x, y) % p);
        } else if(op == 3) {
            x = read(), z = read();
            addsubt(x, z);
        } else if(op == 4) {
            x = read();
            printf("%lld\n", querysubt(x) % p);
        }
    }
    return 0;
}


发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

This site is protected by reCAPTCHA and the Google Privacy Policy and Terms of Service apply.

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据