Educational Codeforces Round 132 E,F

E

容易发现,我们在一个点进行修改,一定可以将经过这个点的非法路径数变成\(0\)。具体的方案就是将\(i\)异或上\(2^{p_{i}}\),其中\(p_{i}>30\),并且\(p\)要两两不同。(比如\(p_{i}=30+i\)就是一组可行的\(p\))

此时我们有一个贪心策略:为了让更多非法路径变成合法的,那么我们修改的点深度要尽可能小。

那么我们从叶子向根考虑:如果当前点\(p\),存在一个非法路径\((u,v)\),满足\(lca(u,v)=p\),那么我们一定修改\(p\),否则再往上就没有这条路径的点了,也无法让这条路径变成合法的了。

如何判断是否存在一个非法路径\((u,v)\),满足\(lca(u,v)=p\)呢?我们发现:设\(f(i)\)为从根到点\(i\)的路径的异或值,那么路径\((u,v)\)的权值异或和就等于\(f(u)\oplus f(v)\oplus a_{lca}\)。那么用\(S_{pos}\)存储当前子树的\(f\)的集合,在遇到一个儿子时,对于\(t\in S_{son}\),我们查询\(t\oplus a_{pos}\)是否在\(S_{pos}\)中出现过,如果有,那么意味着有一条异或和等于\(0\)的路径,其lca等于pos。最后再将\(S_{son}\)并入\(S_{pos}\)中即可。

注意修改\(p\)之后,我们不仅让\(lca=p\)的非法路径成为合法路径,同时还让所有经过\(p\)往根走的非法路径也变成合法的了。故如果我们要修改当前点\(p\),那么\(S_{p}\)需要清空。

具体代码如下:

void cnt(int pos,int fa=0) {
	int is=0;
	val[pos].insert(f[pos]);
	for(auto nxt : G[pos]) {
		if(nxt==fa) continue;
		cnt(nxt,pos);
		for(auto item : val[nxt]) {
			if(val[pos].find(item^a[pos])!=val[pos].end()) {
				is=1;
			}
		}
		val[pos].insert(val[nxt].begin(),val[nxt].end());
	}
	if(is) {
		ans++;
		val[pos].clear();
	}
}

但是这样时间复杂度有可能达到\(O(n^2\log n)\),因为将\(S_{son}\)并入\(S_{pos}\)中这个操作有可能是\(O(n \log n)\)的。如果我们采用启发式合并(将小的Set合并到大的Set中),那么时间复杂度将优化到\(O(n\log ^2 n)\)

整体代码如下:

#include<bits/stdc++.h>
#define debug(...) std::cerr<<#__VA_ARGS__<<" : "<<__VA_ARGS__<<std::endl

const int maxn=200005;
int n,ans;
int a[maxn],f[maxn];
std::set<int> val[maxn];
std::vector<int> G[maxn];

void dfs(int pos,int fa=0) {
	f[pos]=f[fa]^a[pos];
	for(auto nxt : G[pos]) {
		if(nxt==fa) continue;
		dfs(nxt,pos);
	}
}

void cnt(int pos,int fa=0) {
	int is=0;
	val[pos].insert(f[pos]);
	for(auto nxt : G[pos]) {
		if(nxt==fa) continue;
		cnt(nxt,pos);
		if(val[pos].size()<val[nxt].size()) {
			std::swap(val[pos],val[nxt]);
		}
		for(auto item : val[nxt]) {
			if(val[pos].find(item^a[pos])!=val[pos].end()) {
				is=1;
			}
		}
		val[pos].insert(val[nxt].begin(),val[nxt].end());
	}
	if(is) {
		ans++;
		val[pos].clear();
	}
}

int main() {
	scanf("%d",&n);
	for(int i=1;i<=n;i++) {
		scanf("%d",&a[i]);
	}
	for(int i=1;i<=n-1;i++) {
		int x,y; scanf("%d%d",&x,&y);
		G[x].push_back(y);
		G[y].push_back(x);
	}
	dfs(1);	
	cnt(1);
	printf("%d\n",ans);
	return 0;
}

F
我们将这个问题放到Trie上考虑。容易发现我们其实是要选择若干个叶子,并且\(c_{i}\)其实就是对一个子树限制其内部选择的叶子个数不能超过\(c_{i}\)

那么此时我们可以在Trie上DP。设\(f(i,j)\)表示考虑\(i\)为根的子树,我们最多能选择\(j\)个叶子,则我们大致可以列出一个转移式:

\[f(i,j)=\sum_{k1,k2} f(left,k1)f(right,k2)ways \]

其中\(ways\)是当前点\(C_{i}\)的取值数。

现在我们就来算一下\(ways\)的可能取值:

首先\(k1+k2<j\),那么显然不可能,故\(ways=0\)
\(k1+k2=j\),那么\(C_{i}\geq j\)\(ways=k-j+1\)
\(k1+k2>j\),由于当前方案应该是选择最多点的,若\(C_{i}\)也大于\(j\),那么意味着我们可以取更多的叶子。故\(C_{i}=j\)\(ways=1\)

那么我们可以写出具体的转移方程:

\[f(i,j)=(k-j+1)\sum_{k1+k2=j} f(left,k1)f(right,k2) + \sum_{k1+k2>j} f(left,k1)f(right,k2) \]

边界:在叶子上,如果我们选择\(j\)个点,并且这还是选择点最多的方案,显然\(C_{leaf}=j\)。所以:

\[f(leaf,j)=1 \]

边界:在根节点上,我们没有对应的\(C_{root}\),所以\(k1+k2>j\)的方案是不存在的。

\[f(root,j)=(k-j+1)\sum_{k1+k2=j} f(left,k1)f(right,k2) \]

答案:答案就是\(f(root,t)\)。(为了防止变量重复,这里的\(t\)就是输入时的\(f\))要特别注意,\(t>k\)有时候也是有方案的,比如\(1 100000 200000\),此时要特别注意处理到根节点是否考虑了这种情况。

这样做是\(O(2^nk^2)\)的。虽然转移的式子一眼卷积,但也仅仅能优化到\(O(2^nk\log k)\)

注意到一个奇特的性质——对于相同深度的点,所对应的子树其实长得是一样的。(因为这个Trie是棵满二叉树)所以对于深度相同的点,其DP值也是相同的。

所以修改DP数组的定义——设\(f(i,j)\)表示深度为\(i\)的点,我们取\(j\)个叶子的方案数。

转移方程也是类似的:

\[f(i,j)=(k-j+1)\sum_{k1+k2=j} f(i+1,k1)f(i+1,k2) + \sum_{k1+k2>j} f(i+1,k1)f(i+1,k2) \]

此时再用卷积+后缀和,(FFT或NTT)时间复杂度降为\(O(nk\log k)\),就可以通过此题了。

#include<bits/stdc++.h>
#define debug(...) std::cerr<<#__VA_ARGS__<<" : "<<__VA_ARGS__<<std::endl

const int mod=998244353;

int qpow(int x,int y) {
	int ret=1;
	for(;y;y>>=1,x=1ll*x*x%mod) {
		if(y&1) {
			ret=1ll*ret*x%mod;
		}
	}
	return ret;
}

namespace mul {
	const int N=2097152;
	const int g=3;
	int omega[N];
	void init() {
		int Wn=qpow(g,(mod-1)/N);
		omega[0]=1;
		for(int i=1;i<N;i++) {
			omega[i]=1ll*omega[i-1]*Wn%mod;
		}
	}
	void ntt(int *a,int n,int op) {
		for(int i=0,j=0;i<n;i++) {
			if(i<j) std::swap(a[i],a[j]);
			int k=n; while(j<k) k>>=1,j^=k;
		}
		for(int k=1,step=N>>1;k<n;k<<=1,step>>=1) {
			for(int i=0;i<n;i+=k<<1) {
				for(int j=i,cur=0;j<i+k;j++,cur+=step) {
					int tmp=1ll*a[j+k]*omega[cur]%mod;
					a[j+k]=a[j]-tmp+mod;
					a[j]+=tmp;
					if(a[j+k]>=mod) a[j+k]-=mod;
					if(a[j]>=mod) a[j]-=mod;
				}
			}
		}
		if(op==-1) {
			std::reverse(a+1,a+n);
			int iv2=(mod+1)/2,ivn=1;
			for(int i=1;i<n;i<<=1) {
				ivn=1ll*ivn*iv2%mod;
			}
			for(int i=0;i<n;i++) {
				a[i]=1ll*a[i]*ivn%mod;
			}
		}
	}
} using namespace mul;

int n,k,t;
int A[600005],B[600005],S[600005];
int f[20][200005];

int main() {
	init();
	scanf("%d%d%d",&n,&k,&t);
	int len=1; while(len<=k+k) len<<=1;
	for(int i=0;i<=k;i++) f[n][i]=1;
	for(int i=n-1;i>=0;i--) {
		memset(A,0,sizeof A);
		memset(B,0,sizeof B);
		memset(S,0,sizeof S);
		for(int j=0;j<=k;j++) B[j]=f[i+1][j];
		ntt(B,len,1);
		for(int j=0;j<len;j++) A[j]=1ll*B[j]*B[j]%mod;
		ntt(A,len,-1);
		for(int j=len-1;j>=0;j--) {
			S[j]=A[j]+S[j+1]; if(S[j]>=mod) S[j]-=mod;
		}
		if(!i) {
			std::cout<<A[t]<<std::endl; exit(0);
		} else {
			for(int j=0;j<=k;j++) {
				f[i][j]=1ll*(k-j+1)*A[j]%mod+S[j+1];
				if(f[i][j]>=mod) f[i][j]-=mod;
			}
		}
	}
	return 0;
}
posted @ 2022-07-23 10:43  Nastia  阅读(44)  评论(0)    收藏  举报