题解 CF1326F2 Wise Men (Hard Version)

CF1326F2 Wise Men (Hard Version)

题目大意

\(n\) 个人。给出 \(n\) 个人的「认识情况」(双向且保证合法)。

对于每个长度为 \(n\) 的排列 \(p\),可以用它生成一个长度为 \(n - 1\)\(01\)\(s\)。其中 \(s_i\)\(1\) 当且仅当 \(p_{i}\)\(p_{i + 1}\) 认识。

对于所有 \(2^{n - 1}\)\(01\) 串,分别求出它可以由多少个排列生成。

数据范围:\(2\leq n\leq 18\)

前置知识:快速沃尔什变换(FWT)

or卷积

给出两个序列\(a,b\),求一个序列\(c\),使得\(c_i=\sum_{j\operatorname{OR}k=i}a_jb_k\)

仿照FFT的思路,我们构造两个序列\(FWT(a),FWT(b)\)(对应了FFT里的点值),使得\(FWT(c)[i]=FWT(a)[i]\cdot FWT(b)[i]\)。然后再对\(FWT(c)\)做逆变换,得到\(c\)

FWT算法的结论是:对于or卷积,\(FWT(a)[i]=\sum_{j\operatorname{OR}i=i}a_j\)。可以发现,\(j\operatorname{OR}i=i\)就等价于“\(j\)\(i\)的一个子集”。

值得一提的是,根据这个定义,FWT-or就相当于是做高维前缀和;FWT-or的逆变换(IFWT-or)就相当于是高维前缀和的逆变换(差分)

在实现时,对于一个最高次项为\(2^n\)的多项式\(a\),把它分成\(a_0,a_1\)两部分,分别表示前面的\(2^{n-1}\)项和后面的\(2^{n-1}\)项,则:

\[FWT(a)=\begin{cases} (FWT(a_0),FWT(a_0+a_1))&&n>0\\ a&&n=0 \end{cases} \]

这个逗号是啥意思?因为\(FWT(a)\)是一个长度为\(2^n\)的序列,因此逗号左边就是序列的前\(2^{n-1}\)项,右边就是序列的后\(2^{n-1}\)项。

而逆变换就把这个过程反过来即可,即:

\[IFWT(a)=\begin{cases} (IFWT(a_0),IFWT(a_1-a_0))&& n>0\\ a&&n=0 \end{cases} \]

and卷积

给出两个序列\(a,b\),求一个序列\(c\),使得\(c_i=\sum_{j\operatorname{AND}k=i}a_jb_k\)

对于and卷积,\(FWT(a)[i]=\sum_{j\operatorname{AND}i=i}a_j\)。可以发现,\(j\operatorname{AND}i=i\)就等价于“\(i\)\(j\)的一个子集”,和or卷积恰好相反。

同样可以看出,根据这个定义,FWT-and就相当于是做高维后缀和;FWT-and的逆变换(IFWT-and)就相当于是高维后缀和的逆变换(差分)

在实现时,

\[FWT(a)=\begin{cases} (FWT(a_0+a_1),FWT(a_1))&&n>0\\ a&&n=0 \end{cases} \]

同理可以做逆变换:

\[IFWT(a)=\begin{cases} (IFWT(a_0-a_1),IFWT(a_1))&&n>0\\ a&&n=0 \end{cases} \]

xor卷积

与本题无关。只是顺带提一下做法:

\[FWT(a)=\begin{cases} (FWT(a_0+a_1),FWT(a_0-a_1))&&n>0\\ a&&n=0 \end{cases} \]

于是可知,逆变换为:

\[IFWT(a)=\begin{cases} (IFWT(\frac{a_0+a_1}{2}),IFWT(\frac{a_0-a_1}{2}))&&n>0\\ a&&n=0 \end{cases} \]

本题题解

我们设\(ans(s)\)表示串\(s\)的答案。直接求\(ans(s)\)不好求,考虑集合中至少包含\(s\)的答案,即所有\(s\subseteq S\)\(ans(S)\)之和,记为\(ans'(s)\)。然后我们对\(ans'\)数组做IFWT-and卷积,就可以求出所有\(ans(s)\)

把朋友之间的关系看做一张无向图。我们定义一条链的长度为它经过的节点数

那么对于一个长度为\(n-1\)的01串\(s\),它代表的其实是图中的若干条链。具体来讲,如果在串\(s\)后面补上一个\(0\),那么:

  • 串中每段连续的\(1\)是一条链。如果有\(x\)\(1\),则链的长度为\(x+1\)
  • 每个\(0\)是单独的一个节点(也就是一条长度为\(1\)的链)。特别地:一段连续的\(1\)之后的第一个\(0\)除外,它这个位置上的节点已经被计入了上一条连续的\(1\)组成的链中。

按照上述规则,不难发现,所有链的长度之和恰好为\(n\)。而对于一个01串\(s\)来说,\(ans'(s)\)只取决于它划分出的链的长度的可重集。例如:\(ans'(111011)=ans'(110111)\),因为它们的这个可重集都是\(\{1,3,4\}\)

又因为所有链的长度之和恰好为\(n\),故本质不同的可重集数量只有\(P(n)\)种,其中\(P(n)\)表示\(n\)的划分数。\(P(18)=385\)。于是我们只需要对这\(P(n)\)个“链的长度的可重集”,分别求答案。

\(f_{i,mask}\)表示对于一个大小为\(i\)的节点集合\(mask\),图中有多少条链,恰好经过\(mask\)中的这些节点。

如果我们求出了\(f_{i,mask}\)数组,那么对于一个“链的长度的可重集”\(T\),它的答案就是\(\displaystyle\sum_{m_1,\dots,m_{|T|}}\ \prod_{i=1}^{|T|}f_{len(T_i),m_i}\)。其中\(len(T_i)\)表示\(T\)中第\(i\)条链的长度。前面的\(\sum\)枚举的是一个\(m_i\)数组,表示对每个\(i\)各取一个大小为\(len(T_i)\)的点集\(m_i\),要求这些\(m_i\)的并为\([1,n]\)且互相不交。容易发现只要并为\([1,n]\)就必然互相不交,因为它们的\(len(T_i)\)之和为\(n\)。所以我们可以做一个FMT-or卷积。把所有\(f_{len(T_i)}\)\(|T|\)个序列卷起来。卷积结果的\(2^n-1\)项前的系数即为\(T\)这个可重集的答案。

现在最后的问题是\(f_{i,mask}\)数组怎么求。可以做简单的状压DP。设\(dp[mask][j]\)表示经过了\(mask\)中的这些节点,最后一个经过的节点为\(j\)的链的数量。转移时枚举一个不在\(mask\)中切与\(j\)有连边的点作为下一个点即可。则\(f_{i,mask}=\sum_{j=1}^{n}dp[mask][j]\)

DP求\(f_{i,mask}\)的复杂度为\(O(2^nn^2)\),之后枚举每个可重集,求答案的复杂度为\(O(P(n)2^nn)\),其中\(P(18)=385\)

参考代码

//problem:CF1326F2
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

namespace Fread{
const int MAXN=1<<20;
char buf[MAXN],*S,*T;
inline char getchar(){
	if(S==T){
		T=(S=buf)+fread(buf,1,MAXN,stdin);
		if(S==T)return EOF;
	}
	return *S++;
}
}//namespace Fread
#ifdef ONLINE_JUDGE
	#define getchar Fread::getchar
#endif
inline int read(){
	int f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline ll readll(){
	ll f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline int readbit(){
	char ch=getchar();
	while(ch<'0'||ch>'1')ch=getchar();
	return ch-'0';
}
/*  ------  by:duyi  ------  */ // myt天下第一
const int MAXN=18;
int n,a[MAXN+5][MAXN+5];
ll dp[1<<MAXN][MAXN+5],f[MAXN+5][1<<MAXN],h[400],ans[1<<MAXN];

int cnt;
map<vector<int>,int>mp;
vector<int>vec[400],tmp;
void dfs(int cur,int lst){
	if(cur==n+1){
		mp[tmp]=++cnt;
		vec[cnt]=tmp;
		return;
	}
	if(n-cur+1<lst)return;
	for(int i=lst;cur+i-1<=n;++i){
		tmp.pb(i);
		dfs(cur+i,i);
		tmp.pop_back();
	}
}

int bitcnt(uint x){
	int res=0;
	for(int j=0;j<=31;++j)res+=((x>>j)&1u);
	return res;
}
void fwt_or(ll *f,uint n,int flag){
	// FWT_or(A)[i] = \sum_{j|i=i} A[j]
	//即:j是i的一个子集
	for(uint i=1;i<n;i<<=1){
		for(uint j=0;j<n;j+=(i<<1)){
			for(uint k=0;k<i;++k){
				f[i+j+k]+=f[j+k]*flag;
			}
		}
	}
}
void fwt_and(ll *f,uint n,int flag){
	// FWT_and(A)[i] = \sum_{j&i=i} A[j]
	//即:i是j的一个子集
	for(uint i=1;i<n;i<<=1){
		for(uint j=0;j<n;j+=(i<<1)){
			for(uint k=0;k<i;++k){
				f[j+k]+=f[i+j+k]*flag;
			}
		}
	}
}

int main() {
	n=read();
	dfs(1,1);//搜出所有划分数
	for(int i=1;i<=n;++i)for(int j=1;j<=n;++j)a[i][j]=readbit();
	
	//dp[mask][j] 表示经过了mask中这些点,以j结尾的链有多少.
	//用来求出 f[i][mask] 表示经过了大小为i的集合mask的链的数量
	for(int i=1;i<=n;++i)dp[1u<<(i-1)][i]=1;
	for(uint i=1;i<(1u<<n);++i){
		for(int j=1;j<=n;++j)if((i>>(j-1))&1u){
			for(int k=1;k<=n;++k)if(a[j][k]&&!((i>>(k-1))&1u)){
				dp[i|(1u<<(k-1))][k]+=dp[i][j];
			}
		}
		int t=bitcnt(i);
		for(int j=1;j<=n;++j)f[t][i]+=dp[i][j];
	}
	
	for(int i=1;i<=n;++i)fwt_or(f[i],1u<<n,1);
	static ll IE[1<<MAXN],tmp[1<<MAXN];
	IE[0]=1;
	fwt_or(IE,1u<<n,1);
	for(int i=1;i<=cnt;++i){
		for(uint j=0;j<(1u<<n);++j)tmp[j]=IE[j];
		for(uint j=0;j<vec[i].size();++j){
			for(uint k=0;k<(1u<<n);++k)tmp[k]=(ll)tmp[k]*f[vec[i][j]][k];
		}
		fwt_or(tmp,1u<<n,-1);
		h[i]=tmp[(1u<<n)-1];
	}
	for(uint i=0;i<(1u<<(n-1));++i){
		vector<int>tmp;
		for(int j=0;j<=n-1;){
			int jj=j+1;
			while(jj-1<=n-2 && ((i>>(jj-1))&1u))++jj;
			tmp.pb(jj-j);
			j=jj;
		}
		sort(tmp.begin(),tmp.end());
		//for(uint j=0;j<tmp.size();++j)cout<<tmp[j]<<" ";cout<<endl;
		assert(mp.count(tmp));
		ans[i]=h[mp[tmp]];
	}
	fwt_and(ans,(1u<<(n-1)),-1);
	for(uint i=0;i<(1u<<(n-1));++i)printf("%lld ",ans[i]);
	return 0;
}
posted @ 2020-03-20 20:42  duyiblue  阅读(458)  评论(0编辑  收藏  举报