[洛谷2664]树上游戏 题解
题目地址:洛谷:【P2664】树上游戏 – 洛谷
题目描述
lrb有一棵树,树的每个节点有个颜色。给一个长度为n的颜色序列,定义s(i,j) 为i 到j 的颜色数量。以及
sum_i = \sum_{j=1}^n s(i, j)
现在他想让你求出所有的sum[i]
输入输出格式
输入格式:
第一行为一个整数n,表示树节点的数量
第二行为n个整数,分别表示n个节点的颜色c[1],c[2]……c[n]
接下来n-1行,每行为两个整数x,y,表示x和y之间有一条边
输出格式:
输出n行,第i行为sum[i]
输入输出样例
输入样例#1:
5 1 2 3 2 3 1 2 2 3 2 4 1 5
输出样例#1:
10 9 11 9 12
说明
sum[1]=s(1,1)+s(1,2)+s(1,3)+s(1,4)+s(1,5)=1+2+3+2+2=10
sum[2]=s(2,1)+s(2,2)+s(2,3)+s(2,4)+s(2,5)=2+1+2+1+3=9
sum[3]=s(3,1)+s(3,2)+s(3,3)+s(3,4)+s(3,5)=3+2+1+2+3=11
sum[4]=s(4,1)+s(4,2)+s(4,3)+s(4,4)+s(4,5)=2+1+2+1+3=9
sum[5]=s(5,1)+s(5,2)+s(5,3)+s(5,4)+s(5,5)=2+3+3+3+1=12
对于40%的数据,n<=2000
对于100%的数据,1<=n,c[i]<=10^5
题解
参考资料:题解 P2664 【树上游戏】 – Salamander 的博客 – 洛谷博客
超麻烦的点分治,说实话我在写这篇博文的时候自己都不是很懂。
对于每一个重心,我们先计算其子树内点对它答案的贡献,该贡献即为每种颜色在每条路径第一次出现的位置对应的子树大小之和。我们把子树中每种颜色在每条路径第一次出现的位置的子树大小记为cnt数组,该数组的意义实际上就是包含该种颜色的路径数,则子树内点对重心的贡献就是对cnt数组求和。
接着考虑计算经过重心的路径的贡献。对于每个单独的子树,我们可以求出类似上述cnt数组定义的ct数组。对于一个出现在重心到子树节点x的路上上的颜色col,它会与除了该子树以外的其他到重心路径上无该颜色的点组成一条没有被计算过的路径,因此该颜色会对该点的sum数组产生(siz[rt]-siz[v])-(cnt[col]-ct[col])这么多贡献,同时该颜色也会对该点的子树产生贡献,因此要把这个贡献给传递到子树中去。另外,还有可能在别的子树出现过的颜色对该点的贡献,实际上就是对cnt[col]-ct[col]求和。
总复杂度O(n \log n)。
代码
// Code by KSkun, 2018/6
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
#include <vector>
typedef long long LL;
inline char fgc() {
static char buf[100000], *p1 = buf, *p2 = buf;
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2)
? EOF : *p1++;
}
inline LL readint() {
register LL res = 0, neg = 1;
register char c = fgc();
while(!isdigit(c)) {
if(c == '-') neg = -1;
c = fgc();
}
while(isdigit(c)) {
res = (res << 1) + (res << 3) + c - '0';
c = fgc();
}
return res * neg;
}
const int MAXN = 100005;
std::vector<int> gra[MAXN];
int n, c[MAXN];
int siz[MAXN], rt, rtsiz;
bool vis[MAXN];
inline void findrt(int u, int fa, int tot) {
siz[u] = 1; int mxsiz = 0;
for(int i = 0; i < gra[u].size(); i++) {
int v = gra[u][i];
if(vis[v] || v == fa) continue;
findrt(v, u, tot);
siz[u] += siz[v];
mxsiz = std::max(mxsiz, siz[v]);
}
mxsiz = std::max(mxsiz, tot - siz[u]);
if(mxsiz < rtsiz) {
rt = u; rtsiz = mxsiz;
}
}
inline void calsiz(int u, int fa) {
siz[u] = 1;
for(int i = 0; i < gra[u].size(); i++) {
int v = gra[u][i];
if(vis[v] || v == fa) continue;
calsiz(v, u);
siz[u] += siz[v];
}
}
int cnt[MAXN], cct[MAXN], ct[MAXN], col[MAXN], cl[MAXN], top, tp, tot, path;
LL sum[MAXN];
int has[MAXN];
bool exist[MAXN];
inline void dfs(int u, int fa, int *cnt) {
if(!exist[c[u]]) {
col[++top] = c[u]; exist[c[u]] = true;
}
if(++has[c[u]] == 1) cnt[c[u]] += siz[u];
for(int i = 0; i < gra[u].size(); i++) {
int v = gra[u][i];
if(vis[v] || v == fa) continue;
dfs(v, u, cnt);
}
has[c[u]]--;
}
inline void modify(int u, int fa, LL lst) {
LL tag = lst;
if(++has[c[u]] == 1) tag += path - cnt[c[u]];
sum[u] += tot + tag;
for(int i = 0; i < gra[u].size(); i++) {
int v = gra[u][i];
if(vis[v] || v == fa) continue;
modify(v, u, tag);
}
has[c[u]]--;
}
inline void calc(int u) {
calsiz(u, 0);
top = tot = 0;
dfs(u, 0, cnt);
for(int i = 1; i <= top; i++) {
exist[col[i]] = false;
}
tp = top;
for(int i = 1; i <= top; i++) {
tot += cnt[cl[i] = col[i]];
cct[col[i]] = cnt[col[i]];
}
sum[u] += tot;
int temp = tot;
for(int i = 0; i < gra[u].size(); i++) {
int v = gra[u][i];
if(vis[v]) continue;
has[c[u]] = true; top = 0;
dfs(v, u, ct);
has[c[u]] = false;
for(int j = 1; j <= top; j++) {
exist[col[j]] = false;
}
cnt[c[u]] -= siz[v];
tot -= siz[v];
for(int j = 1; j <= top; j++) {
cnt[col[j]] -= ct[col[j]];
tot -= ct[col[j]];
}
path = siz[u] - siz[v];
modify(v, u, 0);
cnt[c[u]] += siz[v];
tot = temp;
for(int j = 1; j <= top; j++) {
cnt[col[j]] = cct[col[j]];
ct[col[j]] = 0;
}
}
for(int i = 1; i <= tp; i++) {
cnt[cl[i]] = 0;
}
vis[u] = true;
}
inline void divide(int u) {
calc(u);
for(int i = 0; i < gra[u].size(); i++) {
int v = gra[u][i];
if(vis[v]) continue;
rt = 0; rtsiz = n; findrt(v, u, siz[v]);
divide(rt);
}
}
int main() {
n = readint();
for(int i = 1; i <= n; i++) {
c[i] = readint();
}
for(int i = 1, u, v; i < n; i++) {
u = readint(); v = readint();
gra[u].push_back(v); gra[v].push_back(u);
}
rt = 0; rtsiz = n; findrt(1, 0, n);
divide(rt);
for(int i = 1; i <= n; i++) {
printf("%lld\n", sum[i]);
}
return 0;
}