LOJ#6289. 花朵 树链剖分+分治NTT
本来以为这道题会非常难调,但是没想到调了不到 5 分钟就 A 了.
由于基于多项式的运算都可以方便地进行封装,所以细节就不是很多(或者说几乎没有细节)
题意:给定一棵树,每个点有点权,求对于所有大小为 $m$ 的独立集的点权之积的和.
数据范围:$n,m \leqslant 8 \times 10^4$.
先考虑一个十分显然的 $O(n^2)$ 暴力:
令 $f[x][i],g[x][i]$ 分别表示点 $x$ 选/不选的情况下独立集大小为 $i$ 的点积 之和.
考虑将 $x$ 与 $x$ 的一个儿子 $y$ 合并:$f[x][i+j]=f[x][i] \times f[y][j]$,$g$ 同理.
然后 $x$ 的初始值是:$f[x][1]=w[x],g[x][0]=1$.
树形DP 卡一下上界复杂度是 $O(n^2)$ 的.
不难发现,上述 $f[x][i+j] = f[x][i] \times f[y][j]$ 是一个卷积的形式.
如果是菊花图或者链的话可以直接用 NTT/分治NTT 来做.
正解的话考虑进行轻重路径剖分:
对于一条重链来说,先求出该重链中每个点轻儿子为根的多项式 $f,g$,然后对于重链中每个点都将其轻儿子与该点合并.
最后对于一条重链进行分治,求出该重链链顶为根的多项式.
分析一下时间复杂度:
考虑一条重链链顶为根的子树会被卷多少次:其祖先中每一条重链都会将其贡献一次.
那么树链剖分中一个点有 $O(\log n)$ 个祖先,而每次卷积的时候对链分治的复杂度是 $O(n \log^2 n)$.
总复杂度就是 $O(n \log^3 n)$,但是由于树链剖分的常数比较小,跑的并不慢.
code:
#include <queue>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#define N 1000009
#define ll long long
#define mod 998244353
#define pb push_back
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
int m;
int A[N<<2],B[N<<2];
int tim,edges,n;
int size[N],son[N],top[N],hd[N],to[N<<1],nex[N<<1],fa[N],dep[N];
int dfn[N],bu[N],si[N],val[N];
void add(int u,int v) {
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
int ADD(int x,int y) {
return (ll)(x+y)%mod;
}
int DEC(int x,int y) {
return (ll)(x-y+mod)%mod;
}
int MUL(int x,int y) {
return (ll)x*y%mod;
}
int qpow(int x,int y) {
int tmp=1;
for(;y;y>>=1,x=(ll)x*x%mod) {
if(y&1) tmp=(ll)tmp*x%mod;
}
return tmp;
}
int get_inv(int x) {
return qpow(x,mod-2);
}
void NTT(int *a,int len,int op) {
for(int i=0,k=0;i<len;++i) {
if(i>k) {
swap(a[i],a[k]);
}
for(int j=len>>1;(k^=j)<j;j>>=1);
}
for(int l=1;l<len;l<<=1) {
int wn=qpow(3,(mod-1)/(l<<1));
if(op==-1) wn=get_inv(wn);
for(int i=0;i<len;i+=l<<1) {
int w=1;
for(int j=0;j<l;++j) {
int x=a[i+j],y=(ll)w*a[i+j+l]%mod;
a[i+j]=(ll)(x+y)%mod;
a[i+j+l]=(ll)(x-y+mod)%mod;
w=(ll)w*wn%mod;
}
}
}
if(op==-1) {
int iv=get_inv(len);
for(int i=0;i<len;++i) {
a[i]=(ll)a[i]*iv%mod;
}
}
}
struct poly {
int len;
vector<int>a;
poly() { len=0,a.clear(); }
void push(int x) {
a.pb(x),++len;
}
void resize(int x) {
a.resize(x),len=x;
}
poly operator*(const poly &b) const {
int lim;
for(lim=1;lim<len+b.len-1;lim<<=1);
for(int i=0;i<lim;++i) A[i]=B[i]=0;
for(int i=0;i<len;++i) A[i]=a[i];
for(int i=0;i<b.len;++i) B[i]=b.a[i];
NTT(A,lim,1),NTT(B,lim,1);
for(int i=0;i<lim;++i) {
A[i]=(ll)A[i]*B[i]%mod;
}
NTT(A,lim,-1);
poly c;
for(int i=0;i<len+b.len-1;++i) {
c.push(A[i]);
}
if(c.len>m+1) c.resize(m+1);
return c;
}
poly operator+(const poly &b) const {
poly c;
c.resize(max(len,b.len));
for(int i=0;i<c.len;++i) c.a[i]=0;
for(int i=0;i<c.len;++i) {
if(i<len) c.a[i]=ADD(c.a[i],a[i]);
if(i<b.len) c.a[i]=ADD(c.a[i],b.a[i]);
}
return c;
}
poly operator-(const poly &b) const {
poly c;
c.resize(max(len,b.len));
for(int i=0;i<c.len;++i) c.a[i]=0;
for(int i=0;i<c.len;++i) {
if(i<len) c.a[i]=ADD(c.a[i],a[i]);
if(i<b.len) c.a[i]=DEC(c.a[i],b.a[i]);
}
return c;
}
}f0[N],f1[N],g[2][N];
struct data {
poly f00,f01,f10,f11;
data operator+(const data &b) const {
data c;
c.f00=(f01*b.f00)+(f00*(b.f00+b.f10));
c.f11=(f11*b.f01)+(f10*(b.f11+b.f01));
c.f01=(f01*b.f01)+(f00*(b.f01+b.f11));
c.f10=(f11*b.f00)+(f10*(b.f10+b.f00));
return c;
}
}tmp;
void dfs1(int x,int ff) {
fa[x]=ff,dep[x]=dep[ff]+1,size[x]=1;
for(int i=hd[x];i;i=nex[i]) {
int y=to[i];
if(y==ff) continue;
dfs1(y,x);
size[x]+=size[y];
if(size[y]>size[son[x]]) son[x]=y;
}
}
void dfs2(int x,int tp) {
top[x]=tp;
dfn[x]=++tim;
bu[tim]=x;
++si[tp];
if(son[x]) {
dfs2(son[x],tp);
}
for(int i=hd[x];i;i=nex[i]) {
if(to[i]!=fa[x]&&to[i]!=son[x]) {
dfs2(to[i],to[i]);
}
}
}
poly calc(int l,int r,int d) {
if(l==r) {
return g[d][l];
}
int mid=(l+r)>>1;
return calc(l,mid,d)*calc(mid+1,r,d);
}
data solve(int l,int r) {
if(l==r) {
int u=bu[l];
data e;
e.f00=f0[u];
e.f11=f1[u];
return e;
}
int mid=(l+r)>>1;
return solve(l,mid)+solve(mid+1,r);
}
int main() {
// setIO("input");
int x,y,z;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i) scanf("%d",&val[i]);
for(int i=1;i<n;++i) {
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
dfs1(1,0),dfs2(1,1);
for(int i=1;i<=n;++i) {
f0[i].push(1);
f1[i].push(0);
f1[i].push(val[i]);
}
for(int i=n;i>=1;--i) {
int p=bu[i];
if(top[p]==p) {
for(int j=dfn[p];j<=dfn[p]+si[p]-1;++j) {
x=bu[j];
int p0=0,p1=0;
for(int e=hd[x];e;e=nex[e]) {
y=to[e];
if(y==son[x]||y==fa[x]) continue;
g[0][++p0]=f0[y]+f1[y];
g[1][++p1]=f0[y];
}
if(p0) f0[x]=calc(1,p0,0);
if(p1) f1[x]=f1[x]*calc(1,p1,1);
}
tmp=solve(dfn[p],dfn[p]+si[p]-1);
f0[p]=tmp.f01+tmp.f00;
f1[p]=tmp.f10+tmp.f11;
}
}
f0[1].resize(m+1);
f1[1].resize(m+1);
printf("%d\n",(ll)(f0[1].a[m]+f1[1].a[m])%mod);
return 0;
}

浙公网安备 33010602011771号