复习

 

普通平衡树

需要注意的点:
1.哨兵节点提前插入  

2.父亲节点注意一下 

3.细心一点 

#include <bits/stdc++.h>    
#define N 300009  
#define lson s[x].ch[0] 
#define rson s[x].ch[1] 
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std; 
const int inf=1000000009;   
struct data { 
    int v,si,ch[2],f;   
    data() { v=si=ch[0]=ch[1]=f=0; }
}s[N];   
int tot,root;   
inline void pushup(int x) { 
    s[x].si=s[lson].si+s[rson].si+1;  
}
inline int get(int x) { 
    return s[s[x].f].ch[1]==x;   
}
inline void rotate(int x) { 
    int old=s[x].f,fold=s[old].f,which=get(x); 
    if(fold) {
        s[fold].ch[s[fold].ch[1]==old]=x; 
    }
    s[old].ch[which]=s[x].ch[which^1]; 
    if(s[old].ch[which]) { 
        s[s[old].ch[which]].f=old; 
    }
    s[x].ch[which^1]=old,s[old].f=x,s[x].f=fold;    
    pushup(old),pushup(x);
}
void splay(int x,int &tar) { 
    int u=s[tar].f; 
    for(int fa;(fa=s[x].f)!=u;rotate(x)) { 
        if(s[fa].f!=u) {    
            rotate(get(fa)==get(x)?fa:x); 
        }
    }
    tar=x;  
}
void ins(int &x,int ff,int v) { 
    if(!x) { 
        x=++tot;
        s[x].f=ff,s[x].v=v,s[x].si=1;  
        return;   
    }
    ins(s[x].ch[v>s[x].v],x,v); 
    pushup(x);
}      
int get_pre(int x,int v) {     
    if(!x) return -1;        
    if(s[x].v<v) {      
        int det=get_pre(rson,v);   
        return det==-1?x:det;  
    }
    else {
        return get_pre(lson,v);   
    }
}
int get_aft(int x,int v) { 
    if(!x) return -1; 
    if(s[x].v>v) { 
        int det=get_aft(lson,v); 
        return det==-1?x:det; 
    }
    else return get_aft(rson,v);   
}
int get_rank(int v) { 
    int x=get_pre(root,v);     
    splay(x,root);          
    return s[s[x].ch[0]].si;  
}   
int get_kth(int x,int kth) {    
    if(s[lson].si+1==kth) return x;   
    else if(s[lson].si>=kth) return get_kth(lson,kth); 
    else return get_kth(rson,kth-s[lson].si-1); 
}      
int find(int x,int v) { 
    if(s[x].v==v) return x;   
    if(s[x].v<v)  return find(rson,v); 
    else return find(lson,v);   
}
void del(int v) {    
    int x=find(root,v); 
    splay(x,root);  
    int l=s[x].ch[0],r=s[x].ch[1];     
    while(s[l].ch[1]) l=s[l].ch[1]; 
    splay(l,s[x].ch[0]);  
    s[l].f=0,s[r].f=l,s[l].ch[1]=r,pushup(l);  
    s[x].ch[0]=s[x].ch[1]=s[x].f=0;    
    root=l;  
}   
int main() {  
    srand(time(NULL));              
    int m,x,y,z; 
    scanf("%d",&m);   
    ins(root,0,inf);   
    ins(root,0,-inf); 
    for(int i=1;i<=m;++i) { 
        int op; 
        scanf("%d",&op); 
        ++op;  
        if(op==1) { 
            scanf("%d",&x); 
            ins(root,0,x);   
            splay(tot,root);
        }
        if(op==2) { 
            scanf("%d",&x);     
            del(x); 
        }
        if(op==3) {  
            scanf("%d",&x); 
            int p=get_kth(root,x+1);  
            splay(p,root); 
            printf("%d\n",s[p].v); 
        }
        if(op==4) { 
            scanf("%d",&x); 
            printf("%d\n",get_rank(x));   
        }   
        if(op==5) { 
            scanf("%d",&x); 
            int p=get_pre(root,x); 
            splay(p,root); 
            if(s[p].v==-inf) printf("-1\n"); 
            else printf("%d\n",s[p].v); 
        }  
        if(op==6) { 
            scanf("%d",&x); 
            int p=get_aft(root,x); 
            splay(p,root); 
            if(s[p].v==inf) printf("-1\n"); 
            else printf("%d\n",s[p].v); 
        }
    }
    return 0;
}

  

矩阵乘法

需要注意的点:

1. 矩阵的初始化

2. 注意新矩阵的 $n,m$ 以及 3 重循环中上界 

#include <bits/stdc++.h> 
#define ll long long    
#define mod 1000000007      
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std;
inline int ADD(int x,int y) { 
    return (x+y)>=mod?x+y-mod:x+y;  
}
struct M {    
    int c[501][501],n,m;  
    M() { memset(c,0,sizeof(c));}  
    int *operator[](int x) { return c[x]; }   
    M operator*(const M b) const {  
        M an;    
        an.n=n; 
        an.m=b.m;   
        for(int i=0;i<n;++i) {
            for(int j=0;j<b.m;++j) {
                for(int k=0;k<m;++k) {      
                    an.c[i][j]=ADD(an.c[i][j],(ll)c[i][k]*b.c[k][j]%mod);  
                }
            }
        }
        return an;   
    }
}A,B;  
int main() { 
    // setIO("input"); 
    int n,p,m;  
    scanf("%d%d%d",&n,&p,&m); 
    A.n=n,A.m=p;  
    for(int i=0;i<n;++i) {  
        for(int j=0;j<p;++j) scanf("%d",&A[i][j]),A[i][j]=ADD(A[i][j],mod);  
    }
    for(int i=0;i<p;++i) { 
        for(int j=0;j<m;++j) scanf("%d",&B[i][j]),B[i][j]=ADD(B[i][j],mod);   
    }
    B.n=p,B.m=m; 
    A=A*B;  
    for(int i=0;i<A.n;++i) { 
        for(int j=0;j<A.m;++j) printf("%d ",A[i][j]); 
        printf("\n"); 
    }
    return 0; 
}

  

多项式

快速傅里叶变换

使用 FFT 的场合比较少,一般都是要结合 MTT 之类的.     

对于复数 $(x,y)$,有 3 种运算:

$(x,y)+(x',y')=(x+x',y+y')$ 

$(x,y)-(x',y')=(x-x',y-y')$ 

$(x,y)*(x',y')=(x*x'-y*y',x*y'+y*x')$

#include <cstdio>  
#include <vector>  
#include <cmath>
#include <cstring>
#include <algorithm>  
#define ll long long  
#define db long double
#define pb push_back
#define N 1000007  
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std;   
const db pi=acos(-1.0);   
struct cp { 
    db x,y;     
    cp(db a=0,db b=0) { x=a,y=b; } 
    cp operator+(const cp &b) const { return cp(x+b.x,y+b.y); }  
    cp operator-(const cp &b) const { return cp(x-b.x,y-b.y); }  
    cp operator*(const cp &b) const { return cp(x*b.x-y*b.y,x*b.y+y*b.x); }
}A[N<<2],B[N<<2];     
void FFT(cp *a,int len,int op) {  
    for(int i=0,k=0;i<len;++i) {    
        if(i>k) swap(a[i],a[k]);  
        for(int j=len>>1;(k^=j)<j;j>>=1); 
    }  
    for(int l=1;l<len;l<<=1) {  
        cp wn(cos(pi/l),op*sin(pi/l)),x,y;  
        for(int i=0;i<len;i+=l<<1) {  
            cp w(1,0);  
            for(int j=0;j<l;++j) { 
                x=a[i+j],y=w*a[i+j+l];   
                a[i+j]=x+y;  
                a[i+j+l]=x-y;  
                w=w*wn;   
            }
        }
    }  
    if(op==-1) {  
        for(int i=0;i<len;++i) a[i].x/=len;  
    }
}
int main() {
    // setIO("input");   
    int n,m,lim,x;    
    scanf("%d%d",&n,&m); 
    for(lim=1;lim<(n+m+1);lim<<=1); 
    for(int i=0;i<=n;++i) { 
        scanf("%d",&x),A[i].x=(db)x;   
    } 
    for(int i=0;i<=m;++i) {     
        scanf("%d",&x),B[i].x=(db)x;  
    } 
    FFT(A,lim,1),FFT(B,lim,1); 
    for(int i=0;i<lim;++i) A[i]=A[i]*B[i]; 
    FFT(A,lim,-1); 
    for(int i=0;i<=n+m;++i) { 
        printf("%d ",(int)(A[i].x+0.5)); 
    }
    return 0; 
}

  

任意模数NTT (MTT) 

当模数不能写成 $a \times 2^k+1$ 的时候就需要用到拆系数 FFT (MTT)了.     

令 $f(x)=wf_{0}(x)+f_{1}(x)$,$g(x)$ 同理.  

然后 $f*g=(wf_{0}+f_{1})(wg_{0}+g_{1})=(f_{0}g_{0})w^2+(f_{0}g_{1}+f_{1}g_{0})w+f_{1}g_{1}$.     

做 7 次 FFT 即可,这个 w 选 $2^{15}$ 就好了.  

#include <cstdio>  
#include <vector>  
#include <cmath>
#include <cstring>
#include <algorithm>  
#define ll long long  
#define db long double
#define pb push_back
#define N 100007  
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std;   
const db pi=acos(-1.0);   
struct cp { 
    db x,y;     
    cp(db a=0,db b=0) { x=a,y=b; } 
    cp operator+(const cp &b) const { return cp(x+b.x,y+b.y); }  
    cp operator-(const cp &b) const { return cp(x-b.x,y-b.y); }  
    cp operator*(const cp &b) const { return cp(x*b.x-y*b.y,x*b.y+y*b.x); }
}f[2][N<<2],g[2][N<<2],ans[3][N<<2];  
int A[N],B[N];
int lim;   
ll C[N];     
void FFT(cp *a,int len,int op) {  
    for(int i=0,k=0;i<len;++i) {    
        if(i>k) swap(a[i],a[k]);  
        for(int j=len>>1;(k^=j)<j;j>>=1); 
    }  
    for(int l=1;l<len;l<<=1) {  
        cp wn(cos(pi/l),op*sin(pi/l)),x,y;  
        for(int i=0;i<len;i+=l<<1) {  
            cp w(1,0);  
            for(int j=0;j<l;++j) { 
                x=a[i+j],y=w*a[i+j+l];   
                a[i+j]=x+y;  
                a[i+j+l]=x-y;  
                w=w*wn;   
            }
        }
    }   
}
ll nor(db x,ll mod) { 
    return (ll)((ll)(x/lim+0.5)%mod+mod)%mod;  
}
void MTT(int *a,int n,int *b,int m,ll mod,ll *c) { 
    for(lim=1;lim<=(n+m);lim<<=1);         
    for(int i=0;i<=n;++i) {    
        f[0][i].x=a[i]>>15;  
        f[1][i].x=a[i]&0x7fff;      
    }  
    for(int i=0;i<=m;++i) { 
        g[0][i].x=b[i]>>15;  
        g[1][i].x=b[i]&0x7fff;   
    }
    
    FFT(f[0],lim,1),FFT(f[1],lim,1);  
    FFT(g[0],lim,1),FFT(g[1],lim,1);    
    for(int i=0;i<lim;++i) { 
        ans[0][i]=f[0][i]*g[0][i]; 
        ans[1][i]=f[0][i]*g[1][i]+f[1][i]*g[0][i];   
        ans[2][i]=f[1][i]*g[1][i]; 
    }   
    FFT(ans[0],lim,-1); 
    FFT(ans[1],lim,-1); 
    FFT(ans[2],lim,-1);        
    for(int i=0;i<=n+m;++i) {           
        ll x=(nor(ans[0][i].x,mod)<<30ll)%mod;    
        ll y=(nor(ans[1][i].x,mod)<<15ll)%mod;   
        ll z=nor(ans[2][i].x,mod)%mod;    
        c[i]=((x+y)%mod+z)%mod;  
    }
}
int main() {
    //setIO("input");       
    int n,m;     
    ll mod;     
    scanf("%d%d%lld",&n,&m,&mod); 
    for(int i=0;i<=n;++i) scanf("%d",&A[i]); 
    for(int i=0;i<=m;++i) scanf("%d",&B[i]);  
    MTT(A,n,B,m,mod,C); 
    for(int i=0;i<=n+m;++i) printf("%lld ",C[i]);  
    return 0; 
}

  

多项式求逆

公式 $B=2B'-AB'^2$   

这里注意复制 $A$ 数组的时候不要复制多了,否则会让前面的 B 多算.    

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#define N 100008 
#define ll long long 
#define pb push_back
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin)
using namespace std; 
int A[N<<2],B[N<<2],f[N<<1],g[N<<1];      
int qpow(int x,int y) {
    int tmp=1;  
    for(;y;y>>=1,x=(ll)x*x%mod) {
        if(y&1) tmp=(ll)tmp*x%mod;  
    } 
    return tmp; 
} 
int get_inv(int x) {
    return qpow(x,mod-2); 
}
void NTT(int *a,int len,int op) {
    for(int i=0,k=0;i<len;++i) {
        if(i>k) swap(a[i],a[k]); 
        for(int j=len>>1;(k^=j)<j;j>>=1);  
    }
    for(int l=1;l<len;l<<=1) {
        int wn=qpow(3,(mod-1)/(l<<1));   
        if(op==-1) {
            wn=get_inv(wn); 
        }
        for(int i=0;i<len;i+=l<<1) {
            int w=1,x,y;  
            for(int j=0;j<l;++j) {
                x=a[i+j],y=(ll)w*a[i+j+l]%mod;  
                a[i+j]=(ll)(x+y)%mod;  
                a[i+j+l]=(ll)(x-y+mod)%mod;  
                w=(ll)w*wn%mod;   
            }
        }
    }
    if(op==-1) {
        int iv=get_inv(len); 
        for(int i=0;i<len;++i) {
            a[i]=(ll)a[i]*iv%mod;  
        }
    }    
}
void get_inv(int *a,int *b,int len,int la) {    
    if(len==1) {
        b[0]=get_inv(a[0]);   
        return;  
    }    
    get_inv(a,b,len>>1,la);   
    int l=len<<1;      
    for(int i=0;i<min(len,la);++i) A[i]=a[i];    
    for(int i=0;i<len>>1;++i) B[i]=b[i];   
    for(int i=min(len,la);i<l;++i) A[i]=0; 
    for(int i=len>>1;i<l;++i) B[i]=0; 
    NTT(A,l,1),NTT(B,l,1);   
    for(int i=0;i<l;++i) {
        A[i]=(ll)A[i]*B[i]%mod*B[i]%mod;  
    } 
    NTT(A,l,-1);   
    for(int i=0;i<len;++i) {
        b[i]=(ll)((ll)(b[i]<<1)%mod-A[i]+mod)%mod;   
    }
}
int main() { 
    // setIO("input");  
    int n,lim;  
    scanf("%d",&n);   
    for(int i=0;i<n;++i) {
        scanf("%d",&g[i]);  
    }
    for(lim=1;lim<n;lim<<=1);   
    get_inv(g,f,lim,n);   
    for(int i=0;i<n;++i) {
        printf("%d ",f[i]);   
    }
    return 0;
}

分治NTT

#include <cstdio>  
#include <cstring>
#include <algorithm>   
#define N 100008 
#define ll long long 
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std; 
int A[N<<2],B[N<<2],f[N],g[N];    
int qpow(int x,int y) { 
    int tmp=1; 
    for(;y;y>>=1,x=(ll)x*x%mod) { 
        if(y&1) tmp=(ll)tmp*x%mod; 
    } 
    return tmp; 
}  
int get_inv(int x) { 
    return qpow(x,mod-2); 
}  
void NTT(int *a,int len,int op) {  
    for(int i=0,k=0;i<len;++i) { 
        if(i>k) swap(a[i],a[k]); 
        for(int j=len>>1;(k^=j)<j;j>>=1); 
    }  
    for(int l=1;l<len;l<<=1) {  
        int wn=qpow(3,(mod-1)/(l<<1)),x,y,w;  
        if(op==-1) {    
            wn=get_inv(wn); 
        }   
        for(int i=0;i<len;i+=l<<1) { 
            w=1; 
            for(int j=0;j<l;++j) {  
                x=a[i+j],y=(ll)a[i+j+l]*w%mod;   
                a[i+j]=(ll)(x+y)%mod;   
                a[i+j+l]=(ll)(x-y+mod)%mod;  
                w=(ll)w*wn%mod;  
            }
        }   
    }     
    if(op==-1) {  
        int iv=get_inv(len);  
        for(int i=0;i<len;++i) { 
            a[i]=(ll)a[i]*iv%mod;  
        }
    }
}
void solve(int l,int r) { 
    if(l==r) { 
        return;   
    }   
    int mid=(l+r)>>1,lim,s1=0,s2=0; 
    solve(l,mid);      
    for(int i=l;i<=mid;++i) A[s1++]=f[i];    
    for(int i=0;i<=r-l;++i) B[s2++]=g[i];   
    for(lim=1;lim<(s1+s2);lim<<=1);    
    for(int i=s1;i<lim;++i) A[i]=0; 
    for(int i=s2;i<lim;++i) B[i]=0;  
    NTT(A,lim,1),NTT(B,lim,1);  
    for(int i=0;i<lim;++i)  A[i]=(ll)A[i]*B[i]%mod;  
    NTT(A,lim,-1);                       
    for(int i=mid+1;i<=r;++i) { 
        (f[i]+=A[i-l])%=mod;   
    }
    solve(mid+1,r); 
}
int main() { 
    // setIO("input"); 
    int n;  
    scanf("%d",&n);    
    for(int i=1;i<n;++i) { 
        scanf("%d",&g[i]); 
    }  
    f[0]=1;     
    solve(0,n-1);  
    for(int i=0;i<n;++i) { 
        printf("%d ",f[i]);  
    }
    return 0; 
}

  

posted @ 2020-07-21 14:48  EM-LGH  阅读(151)  评论(0编辑  收藏  举报