【CF671D】Roads in Yusland
题目
题目链接:https://codeforces.com/problemset/problem/671/D
- 给定一棵 \(n\) 个点的以 \(1\) 为根的树。
- 有 \(m\) 条路径 \((x,y)\),保证 \(y\) 是 \(x\) 或 \(x\) 的祖先,每条路径有一个权值。
- 你要在这些路径中选择若干条路径,使它们能覆盖每条边,同时权值和最小。
- \(n,m \le 3 \times 10^5\)。
思路
设 \(f[x]\) 表示覆盖点 \(x\) 子树内所有边以及 \(x\) 与其父亲的边的最小代价。
但是很明显 \(f[x]\) 不能简单转移。因为有可能花更多代价,覆盖 \(x\) 的祖先更多,这种情况是可能最优的。
所以可以对每一个点维护一个堆,存可能的最优解。
考虑点 \(y\) 怎么转移到其父亲 \(x\)。对于 \(y\) 的堆中一个代价为 \(k\) 的方案,合并到 \(x\) 后,其代价应该是 \(k+\sum_{z\in \text{son}(x),z\neq y} f[z]\)。
也就是说,\(y\) 的所有方案只需要同时加上一个常数,然后扔到 \(x\) 的堆里就好了。直接上左偏树,然后需要搞一个子树加的标记。
但是当某一个方案覆盖不到 \(x\) 与其父亲的边的时候,这个方案就需要删掉了。在每次合并完后不断判断堆顶是否需要删掉即可。
新建一个虚根连向 \(1\),再加一条代价为 \(0\) 的路径,最后输出虚根的 \(f\) 即可。
时间复杂度 \(O(n\log m)\)。
代码
#include <bits/stdc++.h>
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
const int N=300010;
int n,m,tot,head[N],rt[N],dep[N];
ll f[N];
bool flag;
vector<pair<int,int> > a[N];
struct edge
{
int next,to;
}e[N*2];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
struct LeftistTree
{
int tot,dis[N],pos[N],lc[N],rc[N];
ll val[N],lazy[N];
int insert(pair<int,int> b)
{
tot++; val[tot]=b.se; pos[tot]=b.fi;
return tot;
}
void pushdown(int x)
{
if (lazy[x])
{
if (lc[x]) val[lc[x]]+=lazy[x],lazy[lc[x]]+=lazy[x];
if (rc[x]) val[rc[x]]+=lazy[x],lazy[rc[x]]+=lazy[x];
lazy[x]=0;
}
}
int merge(int x,int y)
{
if (!x || !y) return x|y;
pushdown(x); pushdown(y);
if (val[x]>val[y] || (val[x]==val[y] && x>y)) swap(x,y);
rc[x]=merge(rc[x],y);
if (dis[rc[x]]>dis[lc[x]]) swap(lc[x],rc[x]);
dis[x]=dis[rc[x]]+1;
return x;
}
int pop(int x)
{
pushdown(x);
return merge(lc[x],rc[x]);
}
}lit;
void dfs(int x,int fa)
{
dep[x]=dep[fa]+1;
for (int i=0;i<(int)a[x].size();i++)
rt[x]=lit.merge(rt[x],lit.insert(a[x][i]));
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
dfs(v,x); f[x]+=f[v];
if (flag) return;
lit.lazy[rt[v]]-=f[v]; lit.val[rt[v]]-=f[v];
rt[x]=lit.merge(rt[x],rt[v]);
}
}
lit.lazy[rt[x]]+=f[x]; lit.val[rt[x]]+=f[x];
while (rt[x] && dep[lit.pos[rt[x]]]>=dep[x])
rt[x]=lit.pop(rt[x]);
if (!rt[x]) { flag=1; return; }
f[x]=lit.val[rt[x]];
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
n++; add(n,1);
for (int i=1,x,y,z;i<=m;i++)
{
scanf("%d%d%d",&x,&y,&z);
a[x].push_back(mp(y,z));
}
a[1].push_back(mp(0,0));
lit.dis[0]=-1; dep[0]=-1;
dfs(n,0);
if (flag) cout<<"-1";
else cout<<f[n];
return 0;
}