Loading

P7163 [COCI2020-2021#2] Svjetlo

题意

给你一棵点权是 \(0/1\) 的树,你可以从任意一点开始,走到任意一点结束,每到达一个点,都要翻转当前的点权。给定初始的点权,求使得整棵树的点权都变成 \(1\) 的最短路径长度。

Solution

乍一看以为是个换根。。。看题解发现自己 naive 了。

对于求树上最优路径的问题,可以考虑两端是否在一棵子树中。我们令 \(dp_{i,0/1/2,0/1}\) 表示在 \(i\) 的子树中有整个路径的 \(j\) 个端点,并且走完之后 \(x\)\(0/1\) 的点权,此时让整个子树都变成 \(1\) 的最短路径。

为了方便转移,这里的路径采取左闭右开,也就是说在 \(0,2\) 状态中,我们只算入进入根的,不算入出根的。

对于没有端点在子树内的,说明它是从外头进来,然后在子树里捣鼓一圈后又出去。这可以直接从儿子的状态转移过来。然后考虑从当前根走入儿子,然后又出来,回到根,这样根和儿子的状态都变了。那如果想儿子状态翻转,最优的策略就是再从根走儿子,然后回到根。即:

\[dp'_{i,0,0}=\min(dp_{i,0,1}+dp_{s,0,0}+2,dp_{i,0,0}+dp_{s,0,1}+4)\\ dp'_{i,0,1}=\min(dp_{i,0,0}+dp_{s,0,0}+2,dp_{i,0,1}+dp_{s,0,1}+4) \]

对于只有一个端点在子树内的,说明它是从外头进来然后不出去了。这也就意味着,当前根的某一个子树中有一个是含有一个端点的。那么合并一个子树的时候,可能这个端点是在新的子树中的,或者是在原来的子树中的,取最小值即可。即:

\[dp'_{i,1,0}=\min\{dp_{i,0,1}+dp_{s,1,1}+1,dp_{i,0,0}+dp_{s,1,0}+3,dp_{i,1,1}+dp_{s,0,0}+2,dp_{i,1,0}+dp_{s,0,1}+4\}\\ dp'_{i,1,1}=\min\{dp_{i,0,0}+dp_{s,1,1}+1,dp_{i,0,1}+dp_{s,1,0}+3,dp_{i,1,0}+dp_{s,0,0}+2,dp_{i,1,1}+dp_{s,0,1}+4\} \]

对于两个端点都在子树内的,需要更多的讨论。合并两树 \(i,s\) 的时候,有:

  1. \(i\) 中两个端点,\(s\) 中没有端点;
  2. \(s\) 中两个端点,\(i\) 中没有端点;
  3. \(i,s\) 中各一个端点。

对于上面的我们各写出转移是这样的:

  1. \[dp'_{i,2,0}=\min(dp_{i,2,1}+dp_{s,0,0}+2,dp_{i,2,0}+dp_{s,0,1}+4)\\dp'_{i,2,1}=\min(dp_{i,2,0}+dp_{s,0,0}+2,dp_{i,2,1}+dp_{s,0,1}+4) \]

  2. \[dp'_{i,2,0}=\min(dp_{i,0,1}+dp_{s,2,0}+2,dp_{i,0,0}+dp_{s,2,1}+4)\\dp'_{i,2,1}=\min(dp_{i,0,0}+dp_{s,2,0}+2,dp_{i,0,1}+dp_{s,2,1}+4) \]

  3. \[dp'_{i,2,0}=\min(dp_{i,1,0}+dp_{s,1,1},dp_{i,1,1}+dp_{s,1,0}+2)\\dp'_{i,2,1}=\min(dp_{i,1,1}+dp_{s,1,1},dp_{i,1,0}+dp_{s,1,0}+2) \]

然后每类取 \(\min\) 就可以了。

对于每个节点,其子树内的某一个端点可能在当前子树的根上,但是我们在上面并没有计入。对于子树中有一个端点,并且它在根的情况,有:

\[dp'_{i,1,0}=\min(dp_{i,1,0},dp_{i,0,1}+1)\\ dp'_{i,1,1}=\min(dp_{i,1,1},dp_{i,0,0}+1)\\ \]

对于子树中有两个端点,并且有在根的情况,有:

\[dp'_{i,2,0}=\min(dp_{i,2,0},dp_{i,1,0})\\ dp'_{i,2,1}=\min(dp_{i,2,1},dp_{i,1,1}) \]

边界条件:\(dp_{i,0,c}=0\)。最终答案就是 \(dp_{1,2,1}\)

细节:

要从初值为 \(0\) 的点开始遍历,并且如果一个子树中都是 \(1\),那么不用进入这个子树,否则会多算。

Code

#include<bits/stdc++.h>
#define ll long long
#define inf (1<<30)
#define INF (1ll<<60)
#define pb emplace_back
#define pii pair<int,int>
#define mkp make_pair
#define fi first
#define se second
#define all(a) (a).begin(),(a).end()
#define siz(a) (int)(a).size()
#define clr(a) memset(a,0,sizeof(a))
#define rep(i,j,k) for(int i=(j);i<=(k);i++)
#define per(i,j,k) for(int i=(j);i>=(k);i--)
#define pt(a) cerr<<#a<<'='<<a<<' '
#define pts(a) cerr<<#a<<'='<<a<<'\n'
//#define int long long
using namespace std;
const int MAXN=5e5+10;
int c[MAXN],dp[MAXN][3][2],tmp[3][2],flag[MAXN];
vector<int> e[MAXN];
void pdfs(int x,int fa){
	flag[x]=c[x];
	for(int s:e[x]){
		if(s==fa) continue;
		pdfs(s,x);
		flag[x]&=flag[s];
	}
}
void dfs(int x,int fa){
	dp[x][0][c[x]]=0;
	for(int s:e[x]){
		if(s==fa) continue;
		if(flag[s]) continue;
		dfs(s,x);
		rep(i,0,2) rep(j,0,1) tmp[i][j]=dp[x][i][j];
		memset(dp[x],0x3f,sizeof(dp[x]));//Important
		//In case trans from itself
		dp[x][0][0]=min(tmp[0][1]+dp[s][0][0]+2,tmp[0][0]+dp[s][0][1]+4);
		dp[x][0][1]=min(tmp[0][0]+dp[s][0][0]+2,tmp[0][1]+dp[s][0][1]+4);
		
		dp[x][1][0]=min(min(tmp[0][1]+dp[s][1][1]+1,tmp[0][0]+dp[s][1][0]+3),min(tmp[1][1]+dp[s][0][0]+2,tmp[1][0]+dp[s][0][1]+4));
		dp[x][1][1]=min(min(tmp[0][0]+dp[s][1][1]+1,tmp[0][1]+dp[s][1][0]+3),min(tmp[1][0]+dp[s][0][0]+2,tmp[1][1]+dp[s][0][1]+4));
		
		dp[x][2][0]=min(dp[x][2][0],min(tmp[2][1]+dp[s][0][0]+2,tmp[2][0]+dp[s][0][1]+4));
		dp[x][2][1]=min(dp[x][2][1],min(tmp[2][0]+dp[s][0][0]+2,tmp[2][1]+dp[s][0][1]+4));
		
		dp[x][2][0]=min(dp[x][2][0],min(tmp[0][1]+dp[s][2][0]+2,tmp[0][0]+dp[s][2][1]+4));
		dp[x][2][1]=min(dp[x][2][1],min(tmp[0][0]+dp[s][2][0]+2,tmp[0][1]+dp[s][2][1]+4));
		
		dp[x][2][0]=min(dp[x][2][0],min(tmp[1][0]+dp[s][1][1],tmp[1][1]+dp[s][1][0]+2));
		dp[x][2][1]=min(dp[x][2][1],min(tmp[1][1]+dp[s][1][1],tmp[1][0]+dp[s][1][0]+2));
	}
	dp[x][1][0]=min(dp[x][1][0],dp[x][0][1]+1);
	dp[x][1][1]=min(dp[x][1][1],dp[x][0][0]+1);
	dp[x][2][0]=min(dp[x][2][0],dp[x][1][0]);
	dp[x][2][1]=min(dp[x][2][1],dp[x][1][1]);
}
void solve(){
	int n;cin>>n;
	rep(i,1,n){
		char ch;cin>>ch;
		c[i]=ch-'0';
	}
	memset(dp,0x3f,sizeof(dp));
	rep(i,2,n){
		int u,v;cin>>u>>v;
		e[u].pb(v);e[v].pb(u);
	}
	rep(i,1,n) if(c[i]==0){
		pdfs(i,0);dfs(i,0);
		cout<<dp[i][2][1]<<'\n';
		return;
	}
}
signed main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);cout.tie(0);
    // int T;for(cin>>T;T--;)
        solve();
    return 0;
}
posted @ 2022-11-20 19:15  ZCETHAN  阅读(22)  评论(0编辑  收藏  举报