【ABC269 Ex】Antichain
Description
给出一棵 \(n\) 个点、以 \(1\) 为根的树。对于所有 \(k\in[1,n]\),
求出大小恰好为 \(k\) 且其中任意两点互不为祖先-后代关系的点集个数,
答案对 \(998244353\) 取模。
Input
第一行一个数\(n\)
第二行 \(n-1\) 个数,第 \(i\) 个数表示点 \(i+1\) 的父亲。
Output
共 \(n\) 行,第 \(k\) 行一个数,表示大小为 \(k\) 时的答案。
Sample Input
4
1 2 1
Sample Output
4
2
0
0
Data Constraint
\(2\le n\le 2\times 10^5\)
Solution
学会了新科技
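观察转移:设 \(f_u\) 为 \(u\) 子树内合法点集的生成函数(\(x^k\) 的系数为大小为 \(k\) 的方案数,含空集),则 \(f_u=x+\prod_{v\in son(u)} f_v\)——要么只选 \(u\) 本身,要么在各儿子子树中独立地选。\(f_u\) 的次数不超过 \(u\) 的子树大小,
所以可以用(按子树大小划分的)重链剖分优化。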
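记 \(son_u\) 为 \(u\) 的重儿子,\(g_u=\prod_{v\ne son_u} f_v\) 为所有轻儿子的 \(f_v\) 之积(卷积),则 \(f_u=x+f_{son_u}g_u\)。
于是可以写成矩阵形式:\(\begin{pmatrix} f_u & 1\end{pmatrix}=\begin{pmatrix} f_{son_u} & 1\end{pmatrix}\begin{pmatrix} g_u & 0\\ x & 1\end{pmatrix}\)。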
从下往上做:对每条重链,在链顶处分治计算链上各点矩阵的连乘积,再作用到链底叶子的 \(f=1+x\) 上,得到链顶的 \(f\)。
合并轻儿子(求 \(g_u\))时同样用分治 NTT。复杂度大约为 \(O(n\log^3 n)\),实测约 1s 即可通过。
Code
#include<bits/stdc++.h>
using namespace std;
// Inclusive ascending / descending loop helpers used throughout the file.
#define F(i,a,b) for(int i=a;i<=b;i++)
#define Fd(i,a,b) for(int i=a;i>=b;i--)
// Capacity of every global table (also bounds the NTT length via inv[]).
constexpr int N = 600010;
// Prime modulus; init() uses 3 as the generator for NTT roots.
constexpr int mo = 998244353;
using LL = long long;
using ULL = unsigned long long;
// Bit-reversal permutation, filled by BRT().
int rev[N];
// Per-length NTT roots: G1[l] forward root, G2[l] its inverse (filled by init()).
int G1[N], G2[N];
// Factorials, inverse factorials and modular inverses (filled by init()).
int fac[N], ifac[N], inv[N];
// Reduce x from [0, 2*mo) into [0, mo) with a single conditional subtraction.
int mod(int x){
    if(x>=mo)return x-mo;
    return x;
}
// Modular power x^y mod mo by recursive squaring.
// Note: y==1 returns x unchanged (no reduction), y==0 returns 1.
int mi(int x,int y){
    if(y==0)return 1;
    if(y==1)return x;
    const int half=mi(1ll*x*x%mo,y/2);
    if(y&1)return 1ll*x*half%mo;
    return half;
}
void init(){
fac[0]=ifac[0]=1;
F(i,1,N-10)fac[i]=1ll*fac[i-1]*i%mo,inv[i]=(i==1?1:1ll*mo/i*mod(mo-1ll*inv[mo%i]%mo)%mo);
ifac[N-10]=mi(fac[N-10],mo-2);
Fd(i,N-11,1)ifac[i]=1ll*ifac[i+1]*(i+1)%mo;
for(int l=1;l<=N-10;l<<=1)G1[l]=mi(3,(mo-1)/(l*2)),G2[l]=mi(G1[l],mo-2);
}
void BRT(int x){F(i,0,x-1)rev[i]=(rev[i>>1]>>1)|((i&1)?(x>>1):0);}
// Dense polynomial over Z/mo: val[i] is the coefficient of x^i.
// Provides schoolbook / NTT multiplication, addition and subtraction.
struct poly{
vector<int>val;
// poly(c): constant polynomial c; poly(0) is the empty (zero) polynomial.
poly(int x=0){if(x)val.push_back(x);}
poly(const vector<int>&x){val=x;}
void Rev(){reverse(val.begin(),val.end());}
// Append the coefficient of the next higher power of x.
void ins(int x){val.push_back(x);}
// Release all storage by swapping with an empty vector.
void clear(){vector<int>().swap(val);}
int sz(){return val.size();}
void rsz(int x){val.resize(x);}
// Drop trailing zero coefficients.
void shrink(){for(;sz()&&!val.back();val.pop_back());}
// Keep only the first x coefficients (i.e. the polynomial mod x^x).
poly modxn(int x){
if(val.size()<=x)return poly(val);
else return poly(vector<int>(val.begin(),val.begin()+x));
}
// Bounds-checked coefficient read: out-of-range indices read as 0.
int operator[](int x)const{
if(x<0||x>=val.size())return 0;
return val[x];
}
// In-place NTT of length sz(). Requires sz() to be a power of two and rev[]
// to already hold the matching bit-reversal (operator* calls BRT first).
// x==1: forward transform; x==-1: inverse transform, scaled by inv[sz()].
void NTT(int x){
static ULL f[N],w[N];
w[0]=1;
F(i,0,sz()-1)f[i]=(((LL)mo<<5)+val[rev[i]])%mo;
for(int mid=1;mid<sz();mid<<=1){
int tmp=(x==1?G1[mid]:G2[mid]);
// w[j] = tmp^j: twiddle factors for this butterfly width.
F(i,1,mid-1)w[i]=w[i-1]*tmp%mo;
for(int i=0;i<sz();i+=(mid<<1)){
F(j,0,mid-1){
int t=w[j]*f[i|j|mid]%mo;
// Lazy reduction: each level adds less than about mo to an entry, so
// f[] is not reduced here.
f[i|j|mid]=f[i|j]+mo-t;f[i|j]+=t;
}
}
// A single full reduction after the width-1024 level bounds the growth, so
// the products w[j]*f[...] stay inside unsigned 64-bit range for the
// lengths usable here (sz() <= N-10 because of inv[sz()]).
if(mid==(1<<10)){F(i,0,sz()-1)f[i]%=mo;};
}
if(x==-1){int tmp=inv[sz()];F(i,0,sz()-1)val[i]=f[i]%mo*tmp%mo;}
else{F(i,0,sz()-1)val[i]=f[i]%mo;}
}
void DFT(){NTT(1);}
void IDFT(){NTT(-1);}
// Polynomial product. Uses schoolbook multiplication when either factor has
// fewer than 30 coefficients, otherwise NTT of the next power-of-two length
// >= |x|+|y|-1. The result has trailing zeros removed.
friend poly operator*(poly x,poly y){
if(x.sz()<30||y.sz()<30){
if(x.sz()>y.sz())swap(x,y);
poly ret;
ret.rsz(x.sz()+y.sz());
F(i,0,ret.sz()-1){
for(int j=0;j<=i&&j<x.sz();j++)
ret.val[i]=mod(ret.val[i]+1ll*x[j]*y[i-j]%mo);
}
ret.shrink();
return ret;
}
int l=1;
while(l<x.sz()+y.sz()-1)l<<=1;
// Pad both operands to length l and rebuild the global rev[] for that length.
x.rsz(l);y.rsz(l);BRT(l);
x.DFT();y.DFT();
F(i,0,l-1)x.val[i]=1ll*x[i]*y[i]%mo;
x.IDFT();
x.shrink();
return x;
}
// Coefficient-wise sum mod mo; length is that of the longer operand (no shrink).
friend poly operator+(poly x,poly y){
poly ret;
ret.rsz(max(x.sz(),y.sz()));
F(i,0,ret.sz()-1)ret.val[i]=mod(x[i]+y[i]);
return ret;
}
// Coefficient-wise difference mod mo; length is that of the longer operand.
friend poly operator-(poly x,poly y){
poly ret;
ret.rsz(max(x.sz(),y.sz()));
F(i,0,ret.sz()-1)ret.val[i]=mod(x[i]-y[i]+mo);
return ret;
}
poly &operator*=(poly x){return (*this)=(*this)*x;}
poly &operator+=(poly x){return (*this)=(*this)+x;}
poly &operator-=(poly x){return (*this)=(*this)-x;}
};
// Divide-and-conquer product of the polynomials pushed into f[1..n].
// Balanced splitting keeps the operands of each multiplication similar in size.
struct div1{
    int n;
    poly f[N];
    // Product f[l] * f[l+1] * ... * f[r]; requires l <= r.
    poly solve(int l,int r){
        if(l!=r){
            const int m=(l+r)/2;
            poly lo=solve(l,m);
            poly hi=solve(m+1,r);
            return lo*hi;
        }
        return f[l];
    }
    // Store g as the next entry, f[n+1].
    void push(poly g){
        n++;
        f[n]=g;
    }
    // Forget all pushed polynomials (the slots in f[] are overwritten later).
    void clear(){n=0;}
}q1;
// 2x2 matrix with polynomial entries, indexed v[row][col]. It is an aggregate,
// so a braced list {a,b,c,d} fills v[0][0], v[0][1], v[1][0], v[1][1].
struct matrix{poly v[2][2];};
// Ordinary 2x2 matrix product x*y over polynomials:
// res[r][c] = x[r][0]*y[0][c] + x[r][1]*y[1][c].
matrix operator*(matrix x,matrix y){
    matrix res;
    for(int r=0;r<2;r++){
        for(int c=0;c<2;c++){
            res.v[r][c]=(x.v[r][0]*y.v[0][c])+(x.v[r][1]*y.v[1][c]);
        }
    }
    return res;
}
// Divide-and-conquer product of the matrices stored in f[], taken from the
// highest index down: f[r] * f[r-1] * ... * f[l].
struct div2{
    matrix f[N];
    matrix solve(int l,int r){
        if(l>r){
            // Empty range: 2x2 identity (off-diagonal entries stay the zero poly).
            matrix eye;
            eye.v[0][0]=poly(1);
            eye.v[1][1]=poly(1);
            return eye;
        }
        if(l==r)return f[l];
        const int m=(l+r)/2;
        matrix upper=solve(m+1,r);
        matrix lower=solve(l,m);
        // Matrix products do not commute: the higher indices go on the left.
        return upper*lower;
    }
}q2;
// f[u]: polynomial whose x^k coefficient counts size-k antichains in u's
// subtree; dfs2 fills it only for leaves and heavy-chain tops.
poly f[N];
int n,k;
// p: parent; sz: subtree size; son: heavy child (0 when u is a leaf).
int p[N],sz[N],son[N];
// id[t]: vertex with DFS rank t; rk[u]: rank of u; cnt: ranks used so far;
// down[top]: rank of the last vertex on the heavy chain headed by top.
int id[N],rk[N],cnt,down[N];
// e[u]: children of u.
vector<int>e[N];
// Post-order pass: sz[u] = size of u's subtree and son[u] = child with the
// largest subtree (the first such child on ties; stays 0 for a leaf).
void dfs1(int u){
    sz[u]=1;
    for(size_t idx=0;idx<e[u].size();idx++){
        const int v=e[u][idx];
        dfs1(v);
        sz[u]+=sz[v];
        if(sz[son[u]]<sz[v])son[u]=v;
    }
}
void dfs2(int u,int pre){
id[++cnt]=u;rk[u]=cnt;down[pre]=cnt;
if(son[u])dfs2(son[u],pre);
for(auto v:e[u])if(v!=son[u])dfs2(v,v);
q1.clear();
q1.push(poly(1));
for(auto v:e[u])if(v!=son[u])q1.push(f[v]);
poly shit=q1.solve(1,q1.n);
q2.f[rk[u]]=(matrix){q1.solve(1,q1.n),poly(),poly(vector<int>{0,1}),poly(1)};
if(!son[u]){
f[u].ins(1);f[u].ins(1);
}else if(u==pre){
matrix tmp=q2.solve(rk[u],down[u]-1);
f[u]=f[id[down[u]]]*tmp.v[0][0]+tmp.v[1][0];
}
}
int main(){
init();
scanf("%d",&n);
F(i,2,n)scanf("%d",&p[i]),e[p[i]].push_back(i);
dfs1(1);
dfs2(1,1);
F(i,1,n)printf("%d\n",f[1][i]);
return 0;
}

浙公网安备 33010602011771号