CF1616H Keep XOR Low / Trie 树上 DP

题目传送门:CF1616H Keep XOR Low

首先将 \(a\) 中的所有数加入 0/1 Trie 中。

\(pw_i = 2^{sz_i} -1\),即 \(i\) 的子树的所有选法,要去掉空集。

\(f_u\) 表示选择 \(u\) 中的子树满足两两 \(\le x\) 的方案数,假设考虑到 \(x\) 的第 \(d\) 位,令 \(u\) 的左右儿子为 \(lc,rc\)

  • \(x\) 的第 \(d\) 位为 \(0\),那么答案要么在左子树要么在右子树即 \(f_u = f_{lc} + f_{rc}\)

  • \(x\) 的第 \(d\) 位为 \(1\),我们发现答案可以在左子树取,可以在右子树取即 \(pw_{lc}+pw_{rc}\),但是左右我们无法得知,考虑加设状态。

\(f_{u_1,u_2}\) 表示在 \(u_1\)\(u_2\) 子树中选数且 \(u_1\)\(u_2\) 的子树都强制选的方案数。

考虑边界情况,如果到达了叶子那么 \(f_{u_1,u_2} = \begin{cases} pw_{u_1} & u_1 = u_2 \\ pw_{u_1} \times pw_{u_2} & u_1 \neq u_2 \\ \end{cases} \)

假设我们考虑到 \(x\) 的第 \(d\) 位,令 \(u_1\) 的左右儿子为 \(lc_1,rc_1\)\(u_2\) 的左右儿子为 \(lc_2,rc_2\)

  • \(x\) 的第 \(d\) 位为 \(0\),那么答案要么选 \(lc_1\) 的子树和 \(lc_2\) 的子树或 \(rc_1\) 的子树和 \(rc_2\) 的子树,即 \(f_{u_1,u_2} = f_{lc_1,lc_2} + f_{rc_1,rc_2}\)

  • \(x\) 的第 \(d\) 位为 \(1\),那么答案的选取有很多种,令 \(t_1 = f_{lc_1,rc_2},t_2=f_{rc_1,lc_2}\)

  1. \(lc_1\) 的子树和 \(lc_2\) 的子树可以随便选,\(rc_1\)\(rc_2\) 的子树同理,贡献为 \(pw_{lc_1}pw_{lc_2}+pw_{rc_1}pw_{rc_2}\)

  2. 只选择 \(lc_1\)\(rc_2\) 的子树或者 \(lc_2\)\(lc_1\) 的子树,或者全选,贡献为 \(t_1+t_2+t_1t_2\)

  3. 选择 \(lc_1\)\(rc_2\) 的子树剩下 \(rc_1\)\(lc_2\) 的子树选一个,贡献为 \(t_1(pw_{rc_1}+pw_{lc_2})\)

  4. 选择 \(rc_1\)\(lc_2\) 的子树剩下 \(lc_1\)\(rc_2\) 的子树选一个,贡献为 \(t_2(pw_{lc_1}+pw_{rc_2})\)

按照上述方法 dfs 转移即可,注意在 \(u_1=u_2\) 时要按照之前说的转移。

#include<bits/stdc++.h>
#define int long long
#define double long double
using namespace std;
const int N=1.5e5+10,mod=998244353;
inline int read(){
	char c=getchar();
	int f=1,ans=0;
	while(c<48||c>57) f=(c==45?f=-1:1),c=getchar();
	while(c>=48&&c<=57) ans=(ans<<1)+(ans<<3)+(c^48),c=getchar();
	return ans*f;
}
int ch[N*30][2],n,x,sz[N*30],idx=1,pw[N];
inline void insert(int x){
	int p=1;
	for (int i=30;i>=0;i--){
		sz[p]++;
		int tmp=(x>>i)&1;
		if (ch[p][tmp]) p=ch[p][tmp];
		else p=ch[p][tmp]=++idx; 
	}
	sz[p]++;
}
int dfs(int u1,int u2,int d){
	if (!u1||!u2) return 0; 
	if (d<0){
		if (u1==u2) return pw[sz[u1]];
		return pw[sz[u1]]*pw[sz[u2]]%mod;
	}
	int lc1=ch[u1][0],rc1=ch[u1][1],lc2=ch[u2][0],rc2=ch[u2][1];
	if ((x>>d)&1){
		if (u1==u2) return (dfs(lc1,rc1,d-1)+pw[sz[lc1]]+pw[sz[rc1]])%mod;
		int t1=dfs(lc1,rc2,d-1),t2=dfs(rc1,lc2,d-1);
		return (pw[sz[lc1]]*pw[sz[lc2]]%mod+pw[sz[rc1]]*pw[sz[rc2]]%mod+t1*t2%mod+t1+t2+t1*(pw[sz[lc2]]+pw[sz[rc1]])%mod+t2*(pw[sz[lc1]]+pw[sz[rc2]])%mod)%mod;
	}
	else return (dfs(lc1,lc2,d-1)+dfs(rc1,rc2,d-1))%mod;
}
main(){
	n=read(),x=read();pw[0]=1;
	for (int i=1;i<=n;i++) insert(read()),pw[i]=pw[i-1]*2%mod;
	for (int i=0;i<=n;i++) pw[i]=(pw[i]+mod-1)%mod;
	printf("%lld",dfs(1,1,30));
    return 0;
}
posted @ 2026-02-07 14:49  OTn53_qwq  阅读(6)  评论(0)    收藏  举报