[BZOJ1468]Tree 题解
题目地址:洛谷:【P4178】Tree – 洛谷、BZOJ:Problem 1468. — Tree
题目描述
给你一棵TREE,以及这棵树上边的距离.问有多少对点它们两者间的距离小于等于K
输入输出格式
输入格式:
N(n<=40000) 接下来n-1行边描述管道,按照题目中写的输入 接下来是k
输出格式:
一行,有多少对点之间的距离小于等于k
输入输出样例
输入样例#1:
7 1 6 13 6 3 9 3 5 7 4 1 3 2 4 20 4 7 2 10
输出样例#1:
5
题解
点分治的板子。下面只说统计答案的思路。
我们考虑每一对儿子,计算其经过根的路径。如果按照【P3806】【模板】点分治1 – 洛谷这道题O(n^2)地统计答案可能有些吃力,我们要把统计答案改为O(n)算法,即维护两个指针l和r来拼出答案。但是这样计算会有不合法的情况,即两个儿子在同一子树内。那么我们考虑递归到子树里面去把这些不合法的答案减掉。递归下去以后还要把所有儿子的距离加上子树-根这条边的边权的2倍。
代码
// Code by KSkun, 2018/3
#include <cstdio>
#include <vector>
#include <algorithm>
typedef long long LL;
inline char fgc() {
static char buf[100000], *p1 = buf, *p2 = buf;
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2) ? EOF : *p1++;
}
inline LL readint() {
register LL res = 0, neg = 1;
char c = fgc();
while(c < '0' || c > '9') {
if(c == '-') neg = -1;
c = fgc();
}
while(c >= '0' && c <= '9') {
res = res * 10 + c - '0';
c = fgc();
}
return res * neg;
}
const int MAXN = 40005, INF = 1e9;
int n, m, ut, vt, wt, k, rt, res = 0;
int siz[MAXN], dep[MAXN], msz[MAXN], sum;
bool vis[MAXN];
struct Edge {
int to, w;
};
std::vector<Edge> gra[MAXN];
int ans[MAXN], atot = 0;
inline void getroot(int u, int fa) {
siz[u] = 1; msz[u] = 0;
for(int i = 0; i < gra[u].size(); i++) {
int v = gra[u][i].to;
if(vis[v] || v == fa) continue;
getroot(v, u);
siz[u] += siz[v];
msz[u] = std::max(msz[u], siz[v]);
}
msz[u] = std::max(msz[u], sum - siz[u]);
if(msz[u] < msz[rt]) rt = u;
}
inline void caldep(int u, int fa) {
ans[atot++] = dep[u];
for(int i = 0; i < gra[u].size(); i++) {
int v = gra[u][i].to, w = gra[u][i].w;
if(vis[v] || v == fa) continue;
dep[v] = dep[u] + w;
caldep(v, u);
}
}
inline int work(int u, int w) {
atot = 0;
dep[u] = w;
caldep(u, 0);
std::sort(ans, ans + atot);
int l = 0, r = atot - 1;
int resw = 0;
while(l < r) {
if(ans[l] + ans[r] <= k) {
resw += r - l;
l++;
} else {
r--;
}
}
return resw;
}
inline void dfs(int u) {
res += work(u, 0);
vis[u] = true;
for(int i = 0; i < gra[u].size(); i++) {
int v = gra[u][i].to, w = gra[u][i].w;
if(vis[v]) continue;
res -= work(v, w);
rt = 0;
sum = siz[v];
getroot(v, 0);
dfs(rt);
}
}
int main() {
n = readint();
for(int i = 1; i < n; i++) {
ut = readint(); vt = readint(); wt = readint();
gra[ut].push_back(Edge {vt, wt});
gra[vt].push_back(Edge {ut, wt});
}
k = readint();
rt = 0;
msz[0] = INF;
sum = n;
getroot(1, 0);
dfs(rt);
printf("%d", res);
return 0;
}