[BZOJ3091]城市旅行 题解
题目地址:BZOJ:Problem 3091. — 城市旅行
题目描述
输入输出格式
输入格式:
输出格式:
输入输出样例
输入样例#1:
4 5 1 3 2 5 1 2 1 3 2 4 4 2 4 1 2 4 2 3 4 3 1 4 1 4 1 4
输出样例#1:
16/3 6/1
说明
对于所有数据满足 1<=N<=50,000 1<=M<=50,000 1<=Ai<=10^6 1<=D<=100 1<=U,V<=N
题解
参考资料:BZOJ 3091 城市旅行 Link-Cut-Tree – CSDN博客
其实本题是[HAOI2012]高速公路一题扩展到树上的做法。由于树需要动态维护,自然会想到使用LCT。
我们考虑一条树链上的期望应该如何计算,可以求出树链上每个点对期望的分子的贡献,对于a_i,它的贡献是i(n-i+1),即我们要求的分子是\sum_{i=1}^n a_i \cdot i \cdot (n-i+1),我们得想办法在LCT上把这个值维护出来。
于是我们就想到了,能不能先求出子树对应树链的答案,然后再怎么样拼起来呢?
我们观察一下结果吧,假设左子树树链长为4,右子树长为2
想求的东西 1*7*a1 + 2*6*a2 + 3*5*a3 + 4*4*a4 + 5*3*a5 + 6*2*a6 + 7*1*a7 两个子树的答案加起来 1*4*a1 + 2*3*a2 + 3*2*a3 + 4*1*a4 + + 1*2*a6 + 2*1*a7 上式减下式 1*3*a1 + 2*3*a2 + 3*3*a3 + 4*3*a4 + 5*3*a5 + 5*2*a6 + 5*1*a7
a5是树根处的元素可以单独算,3其实是右子树大小+1,5则是左子树大小+1,我们考虑同时维护一个lsum = \sum_{i=1}^n a_i \cdot i和一个rsum = \sum_{i=1}^n a_i \cdot (n-i+1),就可以借助这些来求整条树链的答案了,即
ans = ans_l + ans_r + (siz_l+1)(siz_r+1)a_{root} + (siz_r+1)(lsum_l) + (siz_l+1)(rsum_r)
而lsum、rsum的合并很简单,只需要把子树元素和乘上另一侧子树的个数+1就好,元素和可以顺便维护一下。
此外,本题还要求一个树链加法标记,如何应用这个加法标记到我们维护的信息上就会是问题。lsum、rsum需要加上add \cdot \sum_{i=1}^n i,比较好办,而期望的分子需要加上add \cdot \sum_{i=1}^n i(n-i+1),这个并不好办,不过我们有数学方法得到这个求和的公式,最后得出的结果是\sum_{i=1}^n i(n-i+1) = \frac{n(n+1)(n+2)}{6}。
万事大吉了,剩下的就是把上面的东西码进去。
代码
// Code by KSkun, 2018/6
#include <cstdio>
#include <cctype>
#include <algorithm>
#include <vector>
typedef long long LL;
inline char fgc() {
static char buf[100000], *p1 = buf, *p2 = buf;
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2)
? EOF : *p1++;
}
inline LL readint() {
register LL res = 0, neg = 1; register char c = fgc();
for (; !isdigit(c); c = fgc()) if (c == '-') neg = -1;
for (; isdigit(c); c = fgc()) res = (res << 1) + (res << 3) + c - '0';
return res * neg;
}
const int MAXN = 50005;
struct Node {
int ch[2], fa; LL siz, val, lsum, rsum, sum, exp;
bool rev; LL add;
} tr[MAXN];
inline bool isleft(int p) {
return tr[tr[p].fa].ch[0] == p;
}
inline bool isroot(int p) {
return tr[tr[p].fa].ch[0] != p && tr[tr[p].fa].ch[1] != p;
}
inline void update(int p) {
tr[p].siz = tr[tr[p].ch[0]].siz + tr[tr[p].ch[1]].siz + 1;
tr[p].sum = tr[tr[p].ch[0]].sum + tr[tr[p].ch[1]].sum + tr[p].val;
tr[p].lsum = tr[tr[p].ch[0]].lsum + tr[p].val * (tr[tr[p].ch[0]].siz + 1)
+ tr[tr[p].ch[1]].lsum + tr[tr[p].ch[1]].sum * (tr[tr[p].ch[0]].siz + 1);
tr[p].rsum = tr[tr[p].ch[0]].rsum + tr[tr[p].ch[0]].sum * (tr[tr[p].ch[1]].siz + 1)
+ tr[p].val * (tr[tr[p].ch[1]].siz + 1) + tr[tr[p].ch[1]].rsum;
tr[p].exp = tr[tr[p].ch[0]].exp + tr[tr[p].ch[1]].exp
+ tr[tr[p].ch[0]].lsum * (tr[tr[p].ch[1]].siz + 1)
+ tr[tr[p].ch[1]].rsum * (tr[tr[p].ch[0]].siz + 1)
+ (tr[tr[p].ch[0]].siz + 1) * (tr[tr[p].ch[1]].siz + 1) * tr[p].val;
}
inline void reverse(int p) {
std::swap(tr[p].ch[0], tr[p].ch[1]);
std::swap(tr[p].lsum, tr[p].rsum);
tr[p].rev ^= 1;
}
inline void add(int p, LL v) {
tr[p].val += v;
tr[p].sum += v * tr[p].siz;
tr[p].lsum += v * tr[p].siz * (tr[p].siz + 1) / 2;
tr[p].rsum += v * tr[p].siz * (tr[p].siz + 1) / 2;
tr[p].exp += v * tr[p].siz * (tr[p].siz + 1) * (tr[p].siz + 2) / 6;
tr[p].add += v;
}
inline void pushdown(int p) {
if(tr[p].rev) {
if(tr[p].ch[0]) reverse(tr[p].ch[0]);
if(tr[p].ch[1]) reverse(tr[p].ch[1]);
tr[p].rev ^= 1;
}
if(tr[p].add) {
if(tr[p].ch[0]) add(tr[p].ch[0], tr[p].add);
if(tr[p].ch[1]) add(tr[p].ch[1], tr[p].add);
tr[p].add = 0;
}
}
int sta[MAXN], stop;
inline void pushto(int p) {
stop = 0;
while(!isroot(p)) {
sta[stop++] = p; p = tr[p].fa;
}
sta[stop++] = p;
while(stop) {
pushdown(sta[--stop]);
}
}
inline void rotate(int p) {
bool t = !isleft(p); int fa = tr[p].fa, ffa = tr[fa].fa;
tr[p].fa = ffa; if(!isroot(fa)) tr[ffa].ch[!isleft(fa)] = p;
tr[fa].ch[t] = tr[p].ch[!t]; tr[tr[fa].ch[t]].fa = fa;
tr[p].ch[!t] = fa; tr[fa].fa = p;
update(fa);
}
inline void splay(int p) {
pushto(p);
for(int fa = tr[p].fa; !isroot(p); rotate(p), fa = tr[p].fa) {
if(!isroot(fa)) rotate(isleft(fa) == isleft(p) ? fa : p);
}
update(p);
}
inline void access(int p) {
for(int q = 0; p; q = p, p = tr[p].fa) {
splay(p); tr[p].ch[1] = q; update(p);
}
}
inline void makert(int p) {
access(p); splay(p); reverse(p);
}
inline int findrt(int p) {
access(p); splay(p);
while(tr[p].ch[0]) p = tr[p].ch[0];
return p;
}
inline void split(int u, int v) {
makert(u); access(v); splay(v);
}
inline void link(int u, int v) {
if(findrt(u) == findrt(v)) return;
split(u, v); tr[u].fa = v;
}
inline void cut(int u, int v) {
split(u, v);
if(tr[v].ch[0] != u || tr[u].ch[1]) return;
tr[u].fa = tr[v].ch[0] = 0; update(v);
}
int n, m;
std::vector<int> gra[MAXN];
inline void dfs(int u, int fa) {
for(int i = 0; i < gra[u].size(); i++) {
int v = gra[u][i];
if(v == fa) continue;
tr[v].fa = u;
dfs(v, u);
}
}
inline LL gcd(LL x, LL y) {
if(!y) return x;
return gcd(y, x % y);
}
int main() {
n = readint(); m = readint();
for(int i = 1; i <= n; i++) {
tr[i].val = readint(); update(i);
}
for(int i = 1, a, b; i < n; i++) {
a = readint(); b = readint();
gra[a].push_back(b); gra[b].push_back(a);
}
dfs(1, 0);
int op, u, v; LL d;
while(m--) {
op = readint(); u = readint(); v = readint();
if(op == 1) {
cut(u, v);
} else if(op == 2) {
link(u, v);
} else if(op == 3) {
d = readint(); if(findrt(u) != findrt(v)) continue;
split(u, v); add(v, d);
} else {
if(findrt(u) != findrt(v)) {
puts("-1"); continue;
}
split(u, v);
LL up = tr[v].exp, down = (tr[v].siz + 1) * tr[v].siz / 2, g = gcd(up, down);
printf("%lld/%lld\n", up / g, down / g);
}
}
return 0;
}