loj#2340. 「WC2018」州区划分

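FWT && FMT 板子 (template solutions): three versions follow — FWT with the OR transform, FWT with the XOR transform, and FMT (subset-sum / Möbius transform).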

#include<cstdio>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long LL;
const int _=1e2;            // slack added to the array sizes below
const int maxn=21+5;        // sized for up to 21 cities, plus slack
const int fbin=(1<<21)+_;   // one slot per subset of up to 21 cities, plus slack
const LL mod=998244353;     // prime modulus used for all arithmetic
// Returns A^p modulo `mod` by binary exponentiation.
// A may be any LL value (negative, or >= mod): it is first reduced into
// [0,mod), so the squaring A*A can never overflow a 64-bit LL and the
// result is always a proper residue in [0,mod).
// p is expected to be non-negative; p<=0 yields 1.
LL quick_pow(LL A,int p)
{
    A%=mod;
    if(A<0)A+=mod;   // C++ '%' keeps the sign of the dividend
    LL ret=1;
    while(p>0)
    {
        if(p&1)ret=ret*A%mod;
        A=A*A%mod;
        p>>=1;
    }
    return ret;
}

//--------------------------------------------------def--------------------------------------------------------

// One directed edge of the road graph.
struct node
{
    int x;     // tail city
    int y;     // head city
    int next;  // previously inserted edge leaving x (0 ends the list)
};
node a[maxn*maxn];   // edge storage, filled by ins() at indices len+1, len+2, ...
int len;             // index of the most recently stored edge
int last[maxn];      // last[u]: newest edge leaving u; globals start at 0
// Stores the directed edge x->y in the next free slot of a[] and links it
// in front of x's edge list.
void ins(int x,int y)
{
    node &e=a[++len];
    e.x=x;
    e.y=y;
    e.next=last[x];
    last[x]=len;
}
int w[maxn];    // w[i]: integer read for city i (main sums these per subset)
int du[maxn];   // du[i]: degree of city i inside the subset main is checking
int fa[maxn];   // union-find parent array; fa[i]==i marks a root
// Union-find lookup: returns the root of x's tree and re-points every node
// on the walked path directly at that root (path compression).
int findfa(int x)
{
    int root=x;
    while(fa[root]!=root)root=fa[root];
    while(x!=root)
    {
        const int parent=fa[x];
        fa[x]=root;
        x=parent;
    }
    return root;
}

//--------------------------------------------pre--------------------------------------------------------------

int cnt[fbin];       // cnt[S]: number of cities in subset S (filled by main)
LL h[fbin];          // per-subset factor applied after each DP layer
LL g[maxn][fbin];    // g[k][S]: per-subset term, non-zero only when |S|==k
LL f[maxn][fbin];    // f[k][S]: DP layer for subsets of size k
// In-place OR (subset-sum) transform over a[0..n-1]; assumes n is a power
// of two and every entry lies in [0,mod).
// op==1 : forward, a[S] becomes the sum of a[T] over all T that are subsets of S.
// other : inverse (Moebius), undoing the forward transform.
void FWT(LL *a,int n,int op)
{
    for(int half=1;half<n;half<<=1)
        for(int base=0;base<n;base+=half<<1)
        {
            LL *lo=a+base;        // entries without the current bit
            LL *hi=a+base+half;   // same entries with the current bit set
            if(op==1)
                for(int k=0;k<half;k++)hi[k]=(hi[k]+lo[k])%mod;
            else
                for(int k=0;k<half;k++)hi[k]=(hi[k]-lo[k]+mod)%mod;
        }
}
int main()
{
    int n,li,m,p,x,y;
    scanf("%d%d%d",&n,&m,&p); li=(1<<n);
    len=1;
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d",&x,&y);
        ins(x,y),ins(y,x);
    }
    for(int i=1;i<=n;i++)scanf("%d",&w[i]);
    for(int zt=1;zt<li;zt++)
    {
        for(int i=1;i<=n;i++)
            if((1<<(i-1))&zt)h[zt]+=w[i],cnt[zt]++;
        h[zt]=quick_pow(quick_pow(h[zt],mod-2),p);
        
        bool bk=true;
        memset(du,0,sizeof(du));
        for(int i=1;i<=n;i++)fa[i]=i;
        for(int i=2;i<=len;i+=2)
            if( ((1<<(a[i].x-1))&zt) && ((1<<(a[i].y-1))&zt) )
                    du[a[i].x]++,du[a[i].y]++,fa[findfa(a[i].x)]=fa[findfa(a[i].y)];
        int rt=0;
        for(int i=1;i<=n;i++)
        {
            if(du[i]%2==1){bk=false;break;}
            if((1<<(i-1))&zt)
            {
                if(rt==0)rt=findfa(i);
                else if(rt!=findfa(i)){bk=false;break;}
            }
        }
        if(bk==false)
        {
            for(int i=1;i<=n;i++)
                if((1<<(i-1))&zt)g[cnt[zt]][zt]+=w[i];
            g[cnt[zt]][zt]=quick_pow(g[cnt[zt]][zt],p);
        }
    }
    
    //......pre.........
    
    for(int i=0;i<=n;i++)FWT(g[i],li,1);
    f[0][0]=1;FWT(f[0],li,1);
    for(int i=1;i<=n;i++)
    {
        for(int j=0;j<i;j++)
            for(int zt=0;zt<li;zt++)
            f[i][zt]=(f[i][zt]+f[j][zt]*g[i-j][zt])%mod;
        FWT(f[i],li,-1);
        for(int zt=0;zt<li;zt++)
        {
            if(cnt[zt]!=i)f[i][zt]=0;
            f[i][zt]=f[i][zt]*h[zt]%mod;
        }
        if(i!=n)FWT(f[i],li,1);
    }
    printf("%lld\n",f[n][li-1]);
    
    return 0;
}
Version 1 (above): FWT with the OR transform. Version 2 (below): FWT with the XOR transform.
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long LL;
const int _=1e2;            // slack added to the array sizes below
const int maxn=21+5;        // sized for up to 21 cities, plus slack
const int fbin=(1<<21)+_;   // one slot per subset of up to 21 cities, plus slack
const LL mod=998244353;     // prime modulus used for all arithmetic
const LL inv2=mod/2+1;      // inverse of 2 modulo mod: 2*inv2 == mod+1
// Returns A^p modulo `mod` by binary exponentiation.
// A may be any LL value (negative, or >= mod): it is first reduced into
// [0,mod), so the squaring A*A can never overflow a 64-bit LL and the
// result is always a proper residue in [0,mod).
// p is expected to be non-negative; p<=0 yields 1.
LL quick_pow(LL A,int p)
{
    A%=mod;
    if(A<0)A+=mod;   // C++ '%' keeps the sign of the dividend
    LL ret=1;
    while(p>0)
    {
        if(p&1)ret=ret*A%mod;
        A=A*A%mod;
        p>>=1;
    }
    return ret;
}

//--------------------------------------------------def--------------------------------------------------------

// One directed edge of the road graph.
struct node
{
    int x;     // tail city
    int y;     // head city
    int next;  // previously inserted edge leaving x (0 ends the list)
};
node a[maxn*maxn];   // edge storage, filled by ins() at indices len+1, len+2, ...
int len;             // index of the most recently stored edge
int last[maxn];      // last[u]: newest edge leaving u; globals start at 0
// Stores the directed edge x->y in the next free slot of a[] and links it
// in front of x's edge list.
void ins(int x,int y)
{
    node &e=a[++len];
    e.x=x;
    e.y=y;
    e.next=last[x];
    last[x]=len;
}
int w[maxn];    // w[i]: integer read for city i (main sums these per subset)
int du[maxn];   // du[i]: degree of city i inside the subset main is checking
int fa[maxn];   // union-find parent array; fa[i]==i marks a root
// Union-find lookup: returns the root of x's tree and re-points every node
// on the walked path directly at that root (path compression).
int findfa(int x)
{
    int root=x;
    while(fa[root]!=root)root=fa[root];
    while(x!=root)
    {
        const int parent=fa[x];
        fa[x]=root;
        x=parent;
    }
    return root;
}

//--------------------------------------------pre--------------------------------------------------------------

int cnt[fbin];       // cnt[S]: number of cities in subset S (filled by main)
LL h[fbin];          // per-subset factor applied after each DP layer
LL g[maxn][fbin];    // g[k][S]: per-subset term, non-zero only when |S|==k
LL f[maxn][fbin];    // f[k][S]: DP layer for subsets of size k
// In-place XOR (Walsh-Hadamard) transform over a[0..n-1]; assumes n is a
// power of two and every entry lies in [0,mod).
// Each butterfly maps (u,v) -> (u+v, u-v). With op==-1 both outputs are
// also multiplied by inv2 at every layer, so the inverse is scaled by 1/n
// overall. Any other op is the forward transform.
void FWT(LL *a,int n,int op)
{
    for(int half=1;half<n;half<<=1)
        for(int base=0;base<n;base+=half<<1)
            for(int k=base;k<base+half;k++)
            {
                const LL u=a[k],v=a[k+half];
                LL s=(u+v)%mod;
                LL d=(u-v+mod)%mod;
                if(op==-1)
                {
                    s=s*inv2%mod;
                    d=d*inv2%mod;
                }
                a[k]=s;
                a[k+half]=d;
            }
}
int main()
{
    int n,li,m,p,x,y;
    scanf("%d%d%d",&n,&m,&p); li=(1<<n);
    len=1;
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d",&x,&y);
        ins(x,y),ins(y,x);
    }
    for(int i=1;i<=n;i++)scanf("%d",&w[i]);
    for(int zt=1;zt<li;zt++)
    {
        for(int i=1;i<=n;i++)
            if((1<<(i-1))&zt)h[zt]+=w[i],cnt[zt]++;
        h[zt]=quick_pow(quick_pow(h[zt],mod-2),p);
        
        bool bk=true;
        memset(du,0,sizeof(du));
        for(int i=1;i<=n;i++)fa[i]=i;
        for(int i=2;i<=len;i+=2)
            if( ((1<<(a[i].x-1))&zt) && ((1<<(a[i].y-1))&zt) )
                    du[a[i].x]++,du[a[i].y]++,fa[findfa(a[i].x)]=fa[findfa(a[i].y)];
        int rt=0;
        for(int i=1;i<=n;i++)
        {
            if(du[i]%2==1){bk=false;break;}
            if((1<<(i-1))&zt)
            {
                if(rt==0)rt=findfa(i);
                else if(rt!=findfa(i)){bk=false;break;}
            }
        }
        if(bk==false)
        {
            for(int i=1;i<=n;i++)
                if((1<<(i-1))&zt)g[cnt[zt]][zt]+=w[i];
            g[cnt[zt]][zt]=quick_pow(g[cnt[zt]][zt],p);
        }
    }
    
    //......pre.........
    
    for(int i=0;i<=n;i++)FWT(g[i],li,1);
    f[0][0]=1;FWT(f[0],li,1);
    for(int i=1;i<=n;i++)
    {
        for(int j=0;j<i;j++)
            for(int zt=0;zt<li;zt++)
            f[i][zt]=(f[i][zt]+f[j][zt]*g[i-j][zt])%mod;
        FWT(f[i],li,-1);
        for(int zt=0;zt<li;zt++)
        {
            if(cnt[zt]!=i)f[i][zt]=0;
            f[i][zt]=f[i][zt]*h[zt]%mod;
        }
        if(i!=n)FWT(f[i],li,1);
    }
    printf("%lld\n",f[n][li-1]);
    
    return 0;
}
Version 2 (above): FWT with the XOR transform. Version 3 (below): FMT (subset-sum / Möbius transform).

 

#include<cstdio>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long LL;
const int _=1e2;            // slack added to the array sizes below
const int maxn=21+5;        // sized for up to 21 cities, plus slack
const int fbin=(1<<21)+_;   // one slot per subset of up to 21 cities, plus slack
const LL mod=998244353;     // prime modulus used for all arithmetic
const LL inv2=mod/2+1;      // inverse of 2 modulo mod: 2*inv2 == mod+1
// Returns A^p modulo `mod` by binary exponentiation.
// A may be any LL value (negative, or >= mod): it is first reduced into
// [0,mod), so the squaring A*A can never overflow a 64-bit LL and the
// result is always a proper residue in [0,mod).
// p is expected to be non-negative; p<=0 yields 1.
LL quick_pow(LL A,int p)
{
    A%=mod;
    if(A<0)A+=mod;   // C++ '%' keeps the sign of the dividend
    LL ret=1;
    while(p>0)
    {
        if(p&1)ret=ret*A%mod;
        A=A*A%mod;
        p>>=1;
    }
    return ret;
}

//--------------------------------------------------def--------------------------------------------------------

// One directed edge of the road graph.
struct node
{
    int x;     // tail city
    int y;     // head city
    int next;  // previously inserted edge leaving x (0 ends the list)
};
node a[maxn*maxn];   // edge storage, filled by ins() at indices len+1, len+2, ...
int len;             // index of the most recently stored edge
int last[maxn];      // last[u]: newest edge leaving u; globals start at 0
// Stores the directed edge x->y in the next free slot of a[] and links it
// in front of x's edge list.
void ins(int x,int y)
{
    node &e=a[++len];
    e.x=x;
    e.y=y;
    e.next=last[x];
    last[x]=len;
}
int w[maxn];    // w[i]: integer read for city i (main sums these per subset)
int du[maxn];   // du[i]: degree of city i inside the subset main is checking
int fa[maxn];   // union-find parent array; fa[i]==i marks a root
// Union-find lookup: returns the root of x's tree and re-points every node
// on the walked path directly at that root (path compression).
int findfa(int x)
{
    int root=x;
    while(fa[root]!=root)root=fa[root];
    while(x!=root)
    {
        const int parent=fa[x];
        fa[x]=root;
        x=parent;
    }
    return root;
}

//--------------------------------------------pre--------------------------------------------------------------

int cnt[fbin];       // cnt[S]: number of cities in subset S (filled by main)
LL h[fbin];          // per-subset factor applied after each DP layer
LL g[maxn][fbin];    // g[k][S]: per-subset term, non-zero only when |S|==k
LL f[maxn][fbin];    // f[k][S]: DP layer for subsets of size k
// In-place subset-sum transform over a[0..li-1] using the n lowest bits.
// For each bit, every S containing it gets a[S] += op * a[S without the bit].
// op==1 yields subset sums (zeta transform); op==-1 undoes it (Moebius).
// Assumes op is 1 or -1 and entries lie in [0,mod), so the sum stays non-negative.
void FMT(LL *a,int n,int li,int op)
{
    for(int bit=0;bit<n;bit++)
    {
        const int m=1<<bit;
        for(int zt=0;zt<li;zt++)
        {
            if(!(zt&m))continue;
            a[zt]=(a[zt]+op*a[zt^m]+mod)%mod;
        }
    }
}
int main()
{
    int n,li,m,p,x,y;
    scanf("%d%d%d",&n,&m,&p); li=(1<<n);
    len=1;
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d",&x,&y);
        ins(x,y),ins(y,x);
    }
    for(int i=1;i<=n;i++)scanf("%d",&w[i]);
    for(int zt=1;zt<li;zt++)
    {
        for(int i=1;i<=n;i++)
            if((1<<(i-1))&zt)h[zt]+=w[i],cnt[zt]++;
        h[zt]=quick_pow(quick_pow(h[zt],mod-2),p);
        
        bool bk=true;
        memset(du,0,sizeof(du));
        for(int i=1;i<=n;i++)fa[i]=i;
        for(int i=2;i<=len;i+=2)
            if( ((1<<(a[i].x-1))&zt) && ((1<<(a[i].y-1))&zt) )
                    du[a[i].x]++,du[a[i].y]++,fa[findfa(a[i].x)]=fa[findfa(a[i].y)];
        int rt=0;
        for(int i=1;i<=n;i++)
        {
            if(du[i]%2==1){bk=false;break;}
            if((1<<(i-1))&zt)
            {
                if(rt==0)rt=findfa(i);
                else if(rt!=findfa(i)){bk=false;break;}
            }
        }
        if(bk==false)
        {
            for(int i=1;i<=n;i++)
                if((1<<(i-1))&zt)g[cnt[zt]][zt]+=w[i];
            g[cnt[zt]][zt]=quick_pow(g[cnt[zt]][zt],p);
        }
    }
    
    //......pre.........
    
    for(int i=0;i<=n;i++)FMT(g[i],n,li,1);
    f[0][0]=1;FMT(f[0],n,li,1);
    for(int i=1;i<=n;i++)
    {
        for(int j=0;j<i;j++)
            for(int zt=0;zt<li;zt++)
            f[i][zt]=(f[i][zt]+f[j][zt]*g[i-j][zt])%mod;
        FMT(f[i],n,li,-1);
        for(int zt=0;zt<li;zt++)
        {
            if(cnt[zt]!=i)f[i][zt]=0;
            f[i][zt]=f[i][zt]*h[zt]%mod;
        }
        if(i!=n)FMT(f[i],n,li,1);
    }
    printf("%lld\n",f[n][li-1]);
    
    return 0;
}

 

posted @ 2019-04-10 21:49  AKCqhzdy  阅读(134)  评论(0)  编辑  收藏  举报