[BZOJ3771] Triple 题解

《关于贫穷的樵夫拥有 40000 把斧头这件事》。


相当于是多项式乘法,但是得带容斥,具体自己看代码吧。

#include<bits/stdc++.h>
using namespace std;
const int N=3e5+5;
const long double pi=acos(-1);
namespace FFT{
    int rev[N],mx,k;
    struct comn{long double a,b;};
    struct dft{comn fg[N];};
    comn operator+(comn x,comn y){
        return {x.a+y.a,x.b+y.b};
    }comn operator-(comn x,comn y){
        return {x.a-y.a,x.b-y.b};
    }comn operator*(comn x,comn y){
        return {x.a*y.a-x.b*y.b,x.a*y.b+x.b*y.a};
    }void operator+=(comn &x,comn y){x=x+y;}
    void operator-=(comn &x,comn y){x=x-y;}
    void operator*=(comn &x,comn y){x=x*y;}
    void init(int n){
        mx=1,k=0,rev[0]=0;
        while(mx<=n) mx*=2,k++;
        for(int i=0;i<mx;i++)
            rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
    }void fft(dft &a,int fl){
        for(int i=0;i<mx;i++)
            if(i<rev[i]) swap(a.fg[i],a.fg[rev[i]]);
        comn o={cos(pi),fl*sin(pi)},w={1,0};
        for(int i=1;i<mx;i*=2,o={cos(pi/i),fl*sin(pi/i)})
            for(int j=0;j<mx;j+=i*2,w={1,0})
                for(int l=j;l<j+i;l++){
                    comn x=a.fg[l],y=w*a.fg[l+i];
                    a.fg[l]+=y,a.fg[l+i]=x-y,w*=o;
                }
    }void operator+=(dft &x,dft &y){
        for(int i=0;i<mx;i++) x.fg[i]+=y.fg[i];
    }void operator-=(dft &x,dft &y){
        for(int i=0;i<mx;i++) x.fg[i]-=y.fg[i];
    }
}using namespace FFT;
int n,ans[N],sum[N],sm[N],m;dft al,be,th,de;
int main(){
    ios::sync_with_stdio(0);
    cin.tie(0),cout.tie(0),cin>>n;
    for(int i=1,x;i<=n;i++){
        cin>>x,al.fg[x].a++,be.fg[x*2].a++;
        m=max(m,x),ans[x]++,sum[x*2]++,sm[x*3]++;
    }m*=3,init(m+1),fft(al,1),fft(be,1);
    for(int i=0;i<mx;i++){
        th.fg[i]=al.fg[i]*al.fg[i];
        de.fg[i]=al.fg[i]*be.fg[i];
    }fft(th,-1),fft(de,-1);
    for(int i=1;i<=m;i++){
        ans[i]+=(int)((th.fg[i].a-sum[i])/2/mx+0.5);
        sm[i]+=((int)(de.fg[i].a/mx+0.5)-sm[i])*3;
    }for(int i=1;i<=m;i++) th.fg[i].a/=mx;
    fft(th,1);
    for(int i=0;i<mx;i++)
        th.fg[i]*=al.fg[i];
    fft(th,-1);
    for(int i=1;i<=m;i++)
        ans[i]+=((int)(th.fg[i].a/mx+0.5)-sm[i])/6;
    for(int i=1;i<=m;i++)
        if(ans[i]) cout<<i<<" "<<ans[i]<<"\n";
    return 0;
}
posted @ 2025-01-19 10:34  white_tiger  阅读(19)  评论(0)    收藏  举报