[SDOI2011]消耗战 题解
题目地址:洛谷:【P2495】[SDOI2011]消耗战 – 洛谷、BZOJ:Problem 2286. — [Sdoi2011]消耗战
题目描述
在一场战争中,战场由n个岛屿和n-1个桥梁组成,保证每两个岛屿间有且仅有一条路径可达。现在,我军已经侦查到敌军的总部在编号为1的岛屿,而且他们已经没有足够多的能源维系战斗,我军胜利在望。已知在其他k个岛屿上有丰富能源,为了防止敌军获取能源,我军的任务是炸毁一些桥梁,使得敌军不能到达任何能源丰富的岛屿。由于不同桥梁的材质和结构不同,所以炸毁不同的桥梁有不同的代价,我军希望在满足目标的同时使得总代价最小。
侦查部门还发现,敌军有一台神秘机器。即使我军切断所有能源之后,他们也可以用那台机器。机器产生的效果不仅仅会修复所有我军炸毁的桥梁,而且会重新随机资源分布(但可以保证的是,资源不会分布到1号岛屿上)。不过侦查部门还发现了这台机器只能够使用m次,所以我们只需要把每次任务完成即可。
题意简述
给你一棵包含$n$个点的带边权的树,有$m$次询问,每次给你一个大小为$k_i$的点集,你可以花费边权代价切断树上的一些边,使得这个点集中的点与树根$1$不连通,求满足上述条件的最小切边代价。
输入输出格式
输入格式:
第一行一个整数n,代表岛屿数量。
接下来n-1行,每行三个整数u,v,w,代表u号岛屿和v号岛屿由一条代价为c的桥梁直接相连,保证1<=u,v<=n且1<=c<=100000。
第n+1行,一个整数m,代表敌方机器能使用的次数。
接下来m行,每行一个整数ki,代表第i次后,有ki个岛屿资源丰富,接下来k个整数h1,h2,…hk,表示资源丰富岛屿的编号。
输出格式:
输出有m行,分别代表每次任务的最小代价。
输入输出样例
输入样例#1:
10 1 5 13 1 9 6 2 1 19 2 4 8 2 3 91 5 6 8 7 5 4 7 8 31 10 7 9 3 2 10 6 4 5 7 8 3 3 9 4 6
输出样例#1:
12 32 22
说明
【数据规模和约定】
对于10%的数据,2<=n<=10,1<=m<=5,1<=ki<=n-1
对于20%的数据,2<=n<=100,1<=m<=100,1<=ki<=min(10,n-1)
对于40%的数据,2<=n<=1000,m>=1,sigma(ki)<=500000,1<=ki<=min(15,n-1)
对于100%的数据,2<=n<=250000,m>=1,sigma(ki)<=500000,1<=ki<=n-1
题解
如果只有一次询问,我们可以进行一次DFS处理出$mn[u]$代表$u$点到$1$点链上的最小边,显然,断开它是切断$u$和$1$连接的最优方案。我们用$dp[u]$表示切断$u$子树中所有点与$1$的连接的最小代价,则转移为
$$ dp[u] = \sum_{v \in \mathrm{son}(u)} \min \{ dp[v], mn[v] \} $$
这个转移表示要么切断$v$到$u$的这条边,要么切断其子树内的边。DP初值设置为对于每一个点集里的点,$dp[u]$初始为$mn[u]$,这样在DP求和的时候,如果一个父亲自己也是点集中的点,权值也会被计算进去。显然,如果两个点到$1$路径上的最优边是同一条,那么一定可以在这条最优边处向上转移时只计算一次权值,这是由于$\min$的存在。
然而这个算法是$O(n)$的,也就是说,总复杂度为$O(mn)$,并不能通过本题。
我们考虑建立虚树解决这个问题。虚树是指一棵只包含指定点即它们两两间LCA的辅助用树,这样做可以减少树上点的数量,降低DP的复杂度至$O(k_i \log k_i)$。
具体而言,如果我们处理出树的DFS序,那么对DFS序相邻的两点求LCA,一定可以得到所有点两两的LCA,这是由于,在不同子树内的点的LCA可以在处理到子树DFS序区间端点之间的LCA的时候被求出来。但是,LCA和儿子的边的关系并不明确,直接建边似乎是不明智的选择。我们考虑一种特殊的树的遍历序:欧拉序,它是指DFS遍历时,进入DFS栈的时候将该点加入序列,出栈时也加入一次。如果对于上面的所有点及其LCA的集合,分别插入出入栈点用欧拉序排序,就可以利用这一顺序来模拟DFS了,省去了建边的过程。
模拟DFS的时候,只需要在出栈的时候进行操作,这是因为,此时出栈点$u$子树的DP已经计算完毕,只需要去更新它父亲的DP值即可,而出栈后的栈顶元素就是它的父亲,直接按照上面的方程去转移即可。
使用欧拉序的虚树算法的复杂度是单次$O(n \log n)$的,达到这个复杂度的操作是查询LCA和排序。这里的$n$指的是虚树的节点数,它的规模是$O(k_i \log k_i)$的。
如此,总复杂度终于变得科学了不少。不过本题有点卡常,稍微卡一卡也可以卡进去。
代码
// Code by KSkun, 2018/7
#include <cstdio>
#include <cctype>
#include <cstring>
#include <algorithm>
#include <vector>
typedef long long LL;
inline LL min(LL a, LL b) {
return a < b ? a : b;
}
inline char fgc() {
static char buf[100000], *p1 = buf, *p2 = buf;
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2)
? EOF : *p1++;
}
inline LL readint() {
register LL res = 0, neg = 1; register char c = fgc();
for(; !isdigit(c); c = fgc()) if(c == '-') neg = -1;
for(; isdigit(c); c = fgc()) res = res * 10 + c - '0';
return res * neg;
}
const int MAXN = 500005;
int n, m;
struct Edge {
int to, w;
};
std::vector<Edge> gra[MAXN];
int anc[MAXN][20], mn[MAXN], dep[MAXN], dfn[MAXN], clk;
void dfs(int u, int fa) {
dfn[u] = ++clk;
for(int i = 0; i < gra[u].size(); i++) {
int v = gra[u][i].to;
if(v == fa) continue;
anc[v][0] = u;
dep[v] = dep[u] + 1;
mn[v] = std::min(mn[u], gra[u][i].w);
dfs(v, u);
}
dfn[u + n] = ++clk;
}
inline void calanc() {
for(int i = 1; i <= 19; i++) {
for(int j = 1; j <= n; j++) {
anc[j][i] = anc[anc[j][i - 1]][i - 1];
}
}
}
inline int querylca(int u, int v) {
if(dep[u] > dep[v]) std::swap(u, v);
int del = dep[v] - dep[u];
for(int i = 19; i >= 0; i--) {
if(del & (1 << i)) v = anc[v][i];
}
if(u == v) return u;
for(int i = 19; i >= 0; i--) {
if(anc[u][i] != anc[v][i]) {
u = anc[u][i];
v = anc[v][i];
}
}
return anc[u][0];
}
inline bool cmp(int a, int b) {
return dfn[a] < dfn[b];
}
LL dp[MAXN];
bool vis[MAXN];
std::vector<int> vec;
int sta[MAXN], stop;
int main() {
n = readint();
for(int i = 1, u, v, w; i < n; i++) {
u = readint(); v = readint(); w = readint();
gra[u].push_back(Edge {v, w});
gra[v].push_back(Edge {u, w});
}
mn[1] = 1e9;
dfs(1, 0);
calanc();
m = readint();
while(m--) {
vec.clear();
int k = readint();
for(int i = 1; i <= k; i++) {
int t = readint();
vis[t] = true; dp[t] = mn[t]; vec.push_back(t);
}
std::sort(vec.begin(), vec.end(), cmp);
for(int i = 1; i < k; i++) {
int l = querylca(vec[i - 1], vec[i]);
if(!vis[l]) {
vis[l] = true;
vec.push_back(l);
}
}
if(!vis[1]) vec.push_back(1);
for(int i = 0; vec[i] <= n; i++) {
vec.push_back(vec[i] + n);
}
std::sort(vec.begin(), vec.end(), cmp);
for(int i = 0; i < vec.size(); i++) {
if(vec[i] <= n) {
sta[stop++] = vec[i];
} else {
int u = sta[--stop];
if(u != 1) {
int fa = sta[stop - 1]; dp[fa] += min(mn[u], dp[u]);
} else {
printf("%lld\n", dp[u]);
}
dp[u] = 0; vis[u] = false;
}
}
}
return 0;
}