洛谷 P5642 - 人造情感(换根 dp)

想起来很轻松,写起来很酸爽的套路题。

默认以 \(1\) 为根。先考虑怎么算单个 \(f(u,v)\),我们定义一个连通块的权值为从该连通块中选出若干条点不相交的路径,选出的路径的权值之和的最大值。那么显然 \(f(u,v)\) 就是整棵树的权值 \(-\) 挖掉 \((u,v)\) 这条路径后各个连通块的权值之和。显然,挖掉 \((u,v)\) 以后,剩余的连通块只有两种类型:要么是某个点为根的子树,要么是全图扣掉某个点为根的子树,预处理这两类连通块的权值后用转换贡献体的方式随便推一推系数即可求出答案。

先考虑计算 \(f_u\) 表示以 \(u\) 为根的子树的权值,这部分相对来说比较 easy,如果 \(u\) 没选,那么直接把 \(u\) 所有儿子的 \(f\) 加起来,否则 \(u\) 一定是这条路径的 LCA,枚举这条路径,最大权值和可以树剖算出来。

其次考虑换根 dp 计算 \(g_u\) 表示全图扣掉以 \(u\) 为根的子树,剩余部分的权值和。如果 \(fa_u\) 没选也很 trivial,直接把剩余子树加起来即可。如果 \(fa_u\) 选了,比较棘手的地方在于你不能去枚举这条路径是什么,否则复杂度会退化到平方。考虑维护一些可能成为最大值的路径,然后计算 \(u\) 某个儿子 \(v\)\(g\) 值时就降这些路径中经过 \((u,v)\) 边的路径暂时从集合中删除。具体来说我们将经过 \(u\) 的路径分为三类:

  • 一个端点在 \(u\),另一个端点在 \(u\) 子树外,这样的路径不经过任何 \(u\) 与某个儿子的边,直接枚举这样的路径加入集合即可,总枚举量线性。
  • LCA 为 \(u\),同理直接枚举即可,总枚举量线性,由于这样的路径最多可能经过两个 \(u\) 与儿子的边,所以要注意在计算这两个儿子的 DP 值时把这条路径的权值暂时删除。
  • LCA 不为 \(u\),且不属于第一类,直接枚举路径总枚举量会达到平方,但是注意到我们只关心这条路径经过了 \(u\) 与哪条儿子的边,因此我们考虑用线段树 + 树剖维护 \(h_e\) 表示经过 \(e\) 这条边的这类路径的贡献的最大值,这样可能计入答案的路径只有 \(O(deg_u)\) 个,总枚举量也降到了线性。

时间复杂度 \(O(n\log^2n)\)

const int MAXN=3e5;
const int LOG_N=19;
const int MOD=998244353;
int n,m,hd[MAXN+5],to[MAXN*2+5],nxt[MAXN*2+5],ec=0;
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
int fa[MAXN+5][LOG_N+2],dep[MAXN+5],siz[MAXN+5],wson[MAXN+5];
int top[MAXN+5],dfn[MAXN+5],edt[MAXN+5],rid[MAXN+5],idx;
void dfs1(int x,int f){
	fa[x][0]=f;siz[x]=1;
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==f)continue;dep[y]=dep[x]+1;
		dfs1(y,x);siz[x]+=siz[y];if(siz[y]>siz[wson[x]])wson[x]=y;
	}
}
void dfs2(int x,int tp){
	top[x]=tp;rid[dfn[x]=++idx]=x;
	if(wson[x])dfs2(wson[x],tp);
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==fa[x][0]||y==wson[x])continue;
		dfs2(y,y);
	}edt[x]=idx;
}
int getlca(int x,int y){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]])swap(x,y);
		x=fa[top[x]][0];
	}return (dep[x]<dep[y])?x:y;
}
int get_kanc(int x,int k){
	for(int i=LOG_N;~i;i--)if(k>>i&1)x=fa[x][i];
	return x;
}
int getlst(int x,int y){return get_kanc(y,dep[y]-dep[x]-1);}
struct path{int u,v,w,lc;ll s;}p[MAXN+5];
vector<int>plca[MAXN+5],px[MAXN+5];
ll dp[MAXN+5],dp_out[MAXN+5],A[MAXN+5],S[MAXN+5];
struct fenwick{
	ll t[MAXN+5];
	void add(int x,ll v){for(int i=x;i<=n;i+=(i&(-i)))t[i]+=v;}
	ll query(int x){ll ret=0;for(int i=x;i;i&=(i-1))ret+=t[i];return ret;}
	ll query(int l,int r){return query(r)-query(l-1);}
}T;
struct segtree{
	struct node{int l,r;ll mx;}s[MAXN*4+5];
	void build(int k,int l,int r){
		s[k].l=l;s[k].r=r;if(l==r)return;int mid=l+r>>1;
		build(k<<1,l,mid);build(k<<1|1,mid+1,r);
	}
	ll query(int k,int p){
		if(s[k].l==s[k].r)return s[k].mx;int mid=s[k].l+s[k].r>>1;
		return max(s[k].mx,(p<=mid)?query(k<<1,p):query(k<<1|1,p));
	}
	void makemax(int k,int l,int r,ll v){
		if(l<=s[k].l&&s[k].r<=r)return chkmax(s[k].mx,v),void();
		int mid=s[k].l+s[k].r>>1;
		if(r<=mid)makemax(k<<1,l,r,v);
		else if(l>mid)makemax(k<<1|1,l,r,v);
		else makemax(k<<1,l,mid,v),makemax(k<<1|1,mid+1,r,v);
	}
}segt;
void dfs3(int x){
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==fa[x][0])continue;
		dfs3(y);S[x]+=dp[y];
	}
	dp[x]=S[x];
	for(int id:plca[x]){
		int u=p[id].u,v=p[id].v;ll s=p[id].w;
		if(u!=x)s+=S[u];if(v!=x)s+=S[v];
		s+=S[x];
		if(u!=x)s-=dp[getlst(x,u)];
		if(v!=x)s-=dp[getlst(x,v)];
		while(top[u]!=top[v]){
			if(fa[top[u]]<fa[top[v]])swap(u,v);
			s+=T.query(dfn[top[u]],dfn[u]);
			u=fa[top[u]][0];
		}
		if(dep[u]<dep[v])swap(u,v);
		s+=T.query(dfn[v],dfn[u]);
		chkmax(dp[x],s);p[id].s=s;
	}
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==fa[x][0])continue;
		A[y]=S[x]-dp[y];T.add(dfn[y],A[y]);
	}
}
void dfs4(int x){
	static vector<ll>del[MAXN+5];
	multiset<ll>st;
	for(int id:plca[x]){
		int u=p[id].u,v=p[id].v;ll val=p[id].s+dp_out[x];
		if(u!=x)del[getlst(x,u)].pb(val);
		if(v!=x)del[getlst(x,v)].pb(val);
		st.insert(val);
	}
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==fa[x][0])continue;
		ll mx=segt.query(1,dfn[y]);
		st.insert(mx);del[y].pb(mx);
	}
	for(int id:px[x]){
		int u=p[id].u,v=p[id].v;ll val=p[id].s+dp_out[p[id].lc];
		if(dfn[x]<dfn[u]&&dfn[u]<=edt[x])continue;
		if(dfn[x]<dfn[v]&&dfn[v]<=edt[x])continue;
		st.insert(val);
	}
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==fa[x][0])continue;
		dp_out[y]=dp_out[x]+S[x]-dp[y];
		for(ll v:del[y])st.erase(st.find(v));
		if(!st.empty())chkmax(dp_out[y],(*st.rbegin())-dp[y]);
		for(ll v:del[y])st.insert(v);
	}
	for(int id:plca[x]){
		int u=p[id].u,v=p[id].v;ll val=p[id].s+dp_out[x];
		while(top[u]!=top[v]){
			if(fa[top[u]]<fa[top[v]])swap(u,v);
			segt.makemax(1,dfn[top[u]],dfn[u],val);
			u=fa[top[u]][0];
		}
		if(dep[u]<dep[v])swap(u,v);
		segt.makemax(1,dfn[v],dfn[u],val);
	}
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==fa[x][0])continue;
		dfs4(y);
	}
}
int main(){
//	freopen("in.txt","r",stdin);
	freopen("emotion.in","r",stdin);
	freopen("emotion.out","w",stdout);
	scanf("%d%d",&n,&m);
	for(int i=1,u,v;i<n;i++)scanf("%d%d",&u,&v),adde(u,v),adde(v,u);
	dfs1(1,0);dfs2(1,1);
	for(int i=1;i<=LOG_N;i++)for(int j=1;j<=n;j++)
		fa[j][i]=fa[fa[j][i-1]][i-1];
	for(int i=1;i<=m;i++){
		scanf("%d%d%d",&p[i].u,&p[i].v,&p[i].w);
		p[i].lc=getlca(p[i].u,p[i].v);
		plca[p[i].lc].pb(i);px[p[i].u].pb(i);px[p[i].v].pb(i);
	}
	dfs3(1);segt.build(1,1,n);dfs4(1);
	int ans=1ll*dp[1]%MOD*n%MOD*n%MOD;
	for(int i=1;i<=n;i++){
		ll cnt=1ll*siz[i]*siz[i];
		for(int e=hd[i];e;e=nxt[e]){
			int j=to[e];if(j==fa[i][0])continue;
			cnt-=1ll*siz[j]*siz[j];
		}ans=(ans-cnt%MOD*(dp_out[i]%MOD)%MOD+MOD)%MOD;
		ans=(ans+2ll*dp[i]%MOD*siz[i]%MOD*(n-siz[i]))%MOD;
	}
	for(int i=1;i<=n;i++){
		ll cnt=1ll*n*n;
		for(int e=hd[i];e;e=nxt[e]){
			int j=to[e];if(j==fa[i][0])continue;
			cnt-=1ll*siz[j]*siz[j];
		}cnt-=1ll*(n-siz[i])*(n-siz[i]);
		ans=(ans-cnt%MOD*(S[i]%MOD)%MOD+MOD)%MOD;
	}
	printf("%d\n",ans);
	return 0;
}
/*
9 8
1 2
1 3
2 4
2 5
5 6
5 7
5 8
3 9
6 7 3
6 8 5
8 8 3
4 8 10
2 2 1
4 4 4
1 9 6
2 9 4
*/
posted @ 2023-03-31 13:50  tzc_wk  阅读(114)  评论(0)    收藏  举报