Template(Updating)
1、Splay
(Tyvj1728)
#include <bits/stdc++.h> using namespace std; const int MAXN=1e5+10,INF=1<<30; int n,op,x,root,ch[MAXN][2],sz[MAXN],cnt[MAXN],val[MAXN],f[MAXN],tot=0; void pushup(int x) {sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];} void Rotate(int x) { int y=f[x],z=f[y],k=(x==ch[y][1]); ch[z][y==ch[z][1]]=x;f[x]=z; ch[y][k]=ch[x][k^1];f[ch[x][k^1]]=y; ch[x][k^1]=y;f[y]=x; pushup(x);pushup(y); } void Splay(int x,int up) { while(f[x]!=up) { int y=f[x],z=f[y]; if(z!=up) (ch[y][0]==x)^(ch[z][0]==y)?Rotate(x):Rotate(y); Rotate(x); } if(!up) root=x; } void Insert(int x) { int k=root,anc=0; while(k&&x!=val[k]) anc=k,k=ch[k][x>val[k]]; if(k) cnt[k]++; else { k=++tot; if(anc) ch[anc][x>val[anc]]=k; ch[k][0]=ch[k][1]=0; val[k]=x;cnt[k]=1;sz[k]=1;f[k]=anc; } Splay(k,0); } void Find(int x) { int k=root; if(!k) return; while(ch[k][x>val[k]]&&x!=val[k]) k=ch[k][x>val[k]]; Splay(k,0); } int Next(int x,int flag) { Find(x); int k=root; if((val[k]<x&&!flag)||(val[k]>x&&flag)) return k; k=ch[k][flag]; while(ch[k][flag^1]) k=ch[k][flag^1]; return k; } int Kth(int x) { int k=root; if(sz[k]<x) return 0; while(true) { if(x>sz[ch[k][0]]+cnt[k]) x-=sz[ch[k][0]]+cnt[k],k=ch[k][1]; else if(sz[ch[k][0]]>=x) k=ch[k][0]; else return val[k]; } } void Delete(int x) { int lst=Next(x,0),nxt=Next(x,1); Splay(lst,0);Splay(nxt,lst); if(cnt[ch[nxt][0]]>1) cnt[ch[nxt][0]]--,Splay(ch[nxt][0],0); else ch[nxt][0]=0; } int main() { scanf("%d",&n); Insert(INF);Insert(-INF);//记得先插入边界 for(int i=1;i<=n;i++) { scanf("%d%d",&op,&x); if(op==1) Insert(x); else if(op==2) Delete(x); else if(op==3) Find(x),printf("%d\n",sz[ch[root][0]]); else if(op==4) printf("%d\n",Kth(x+1)); else if(op==5) printf("%d\n",val[Next(x,0)]); else if(op==6) printf("%d\n",val[Next(x,1)]); } return 0; }
Tip:由于采取求前驱后继来删除节点的方式,要先$insert(INF)$和$insert(-INF)$
2、Treap
(Tyvj1728)
#include <bits/stdc++.h> using namespace std; const int MAXN=100005; const int INF=1<<27; int cnt=0,root,n,s[MAXN][2],siz[MAXN],val[MAXN],pri[MAXN]; void update(int i) {siz[i]=siz[s[i][0]]+siz[s[i][1]]+1;} void spin(int &i,int f) { int v=s[i][f^1]; s[i][f^1]=s[v][f];s[v][f]=i; update(i);update(i=v); } void ins(int &i,int k) { if(!i){i=++cnt;siz[i]=1;val[i]=k;pri[i]=rand();return;} siz[i]++; if(val[i]>=k){ins(s[i][0],k);if(pri[i]>pri[s[i][0]]) spin(i,1);} else{ins(s[i][1],k);if(pri[i]>pri[s[i][1]]) spin(i,0);} } int pre(int i,int k) { if(!i) return -INF; if(val[i]<k) return max(val[i],pre(s[i][1],k)); else return pre(s[i][0],k); } int nxt(int i,int k) { if(!i) return INF; if(val[i]>k) return min(val[i],nxt(s[i][0],k)); else return nxt(s[i][1],k); } int find_key(int i,int k) { if(siz[s[i][0]]==k-1) return val[i]; if(siz[s[i][0]]>=k) return find_key(s[i][0],k); return find_key(s[i][1],k-siz[s[i][0]]-1); } int find_rank(int i,int k) { if(!i) return 1; if(val[i]>=k) return find_rank(s[i][0],k); return siz[s[i][0]]+1+find_rank(s[i][1],k); } void del(int &i,int k) { if(val[i]==k) { if(s[i][0]*s[i][1]==0) { i=s[i][0]+s[i][1]; return; } if(pri[s[i][0]]<pri[s[i][1]]) spin(i,1),del(s[i][1],k); else spin(i,0),del(s[i][0],k); } else if(val[i]>=k) del(s[i][0],k); else del(s[i][1],k); update(i); } int main() { scanf("%d",&n);srand(time(NULL)); while(n--) { int op,x;scanf("%d%d",&op,&x); if(op==1) ins(root,x); else if(op==2) del(root,x); else if(op==3) printf("%d\n",find_rank(root,x)); else if(op==4) printf("%d\n",find_key(root,x)); else if(op==5) printf("%d\n",pre(root,x)); else printf("%d\n",nxt(root,x)); } return 0; }
Tip:注意递归边界
3、计算几何模板
我现在发现了我在计算几何方面就是个sillycross
在这里总结一下基本模块吧,
(使用complex<double>模板类)
点乘:
double dot(point a,point b){return real(a*conj(b));}
叉乘:
double det(point a,point b){return imag(a*conj(b));}
判断点x是否在线段[L,R]上:
bool on_seg(point x,point L,point R){return det(L-x,R-x)==0 && dot(L-x,R-x)<=0;}
判断三点是否在一条直线 + L,R是否在x的两侧
判断两线段是否相交(非严格):
bool seg_cross(point a,point b,point c,point d)
{
double s1=det(c-a,b-a)*det(b-a,d-a);
double s2=det(a-c,d-c)*det(d-c,b-c);
if(s1<0 || s2<0) return false;
if(s1==0 && s2==0) return on_seg(c,a,b) || on_seg(d,a,b);
return true;
}
对于每个点判断另一线段的两点是否在其两端 + 对一条线段的端点恰在另一线段上的特殊处理
求点x关于线段[A,B]的对称点:
point sym(point x,point A,point B){return 2*dot(A,B)/dot(B,B)*B-A+x;}
将线段[x,A]延长一倍,求出线段[x,x']的向量,再行加减即可
4、FFT
#include <bits/stdc++.h> using namespace std; #define X first #define Y second #define pb push_back typedef double db; typedef long long ll; typedef pair<int,int> P; const int MAXN=3e6+10; struct Complex { db x,y; Complex(db a=0,db b=0){x=a;y=b;} Complex operator + (const Complex& rhs) {return Complex(x+rhs.x,y+rhs.y);} Complex operator - (const Complex& rhs) {return Complex(x-rhs.x,y-rhs.y);} Complex operator * (const Complex& rhs) {return Complex(x*rhs.x-y*rhs.y,x*rhs.y+y*rhs.x);} }a[MAXN],b[MAXN]; int n,m,lmt=1,dgt,par[MAXN]; void FFT(Complex *a,int flag) { for(int i=0;i<lmt;i++) if(i<par[i]) swap(a[i],a[par[i]]); for(int len=1;len<lmt;len<<=1) { Complex unit(cos(M_PI/len),flag*sin(M_PI/len)); for(int st=0;st<lmt;st+=(len<<1)) { Complex w(1,0); for(int k=st;k<st+len;k++,w=w*unit) { Complex A=a[k],B=w*a[k+len]; a[k]=A+B;a[k+len]=A-B; } } } if(flag==-1) for(int i=0;i<=n+m;i++) a[i].x=floor(a[i].x/lmt+0.5); } int main() { scanf("%d%d",&n,&m); for(int i=0;i<=n;i++) scanf("%lf",&a[i].x); for(int i=0;i<=m;i++) scanf("%lf",&b[i].x); while(lmt<=n+m) lmt<<=1,dgt++; for(int i=0;i<lmt;i++) par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1)); FFT(a,1);FFT(b,1); for(int i=0;i<lmt;i++) a[i]=a[i]*b[i]; FFT(a,-1); for(int i=0;i<=n+m;i++) printf("%d ",(int)a[i].x); return 0; }
5、NTT
#include <bits/stdc++.h> using namespace std; #define X first #define Y second #define pb push_back typedef double db; typedef long long ll; typedef pair<int,int> P; const int MAXN=4e6+10,MOD=998244353; ll n,m,a[MAXN],b[MAXN],dgt,lmt=1,par[MAXN]; ll quick_pow(ll a,ll b) { ll ret=1; for(;b;b>>=1,a=a*a%MOD) if(b&1) ret=ret*a%MOD; return ret; } void FFT(ll *a,int flag) { for(int i=0;i<lmt;i++) if(i<par[i]) swap(a[i],a[par[i]]); for(int len=1;len<lmt;len<<=1) { ll unit=quick_pow(3,(MOD-1)/(len<<1)); if(flag==-1) unit=quick_pow(unit,MOD-2); for(int st=0;st<lmt;st+=(len<<1)) { ll w=1; for(int k=st;k<st+len;k++,w=w*unit%MOD) { ll A=a[k],B=w*a[k+len]%MOD; a[k]=(A+B)%MOD;a[k+len]=(A-B+MOD)%MOD; } } } } int main() { scanf("%lld%lld",&n,&m); for(int i=0;i<=n;i++) scanf("%lld",&a[i]); for(int i=0;i<=m;i++) scanf("%lld",&b[i]); while(lmt<=n+m) lmt<<=1,dgt++; for(int i=0;i<lmt;i++) par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1)); FFT(a,1);FFT(b,1); for(int i=0;i<lmt;i++) (a[i]*=b[i])%=MOD; FFT(a,-1); ll inv=quick_pow(lmt,MOD-2); for(int i=0;i<=n+m;i++) printf("%lld ",a[i]*inv%MOD); return 0; }
6、MTT
#include <bits/stdc++.h> using namespace std; #define X first #define Y second #define pb push_back typedef double db; typedef long long ll; typedef pair<int,int> P; const int MAXN=4e5+10; ll p[]={469762049,998244353,1004535809}; int n,m,MOD,F[MAXN],G[MAXN],dgt,lmt=1; ll a[3][MAXN],b[MAXN],res[MAXN],par[MAXN]; ll quickpow(ll a,ll b,ll MOD) { a%=MOD;ll ret=1; for(;b;b>>=1,a=a*a%MOD) if(b&1) ret=ret*a%MOD; return ret; } ll mul(ll a,ll b,ll MOD) { a=(a%MOD+MOD)%MOD; b=(b%MOD+MOD)%MOD;ll ret=0; for(;b;b>>=1,a=(a+a)%MOD) if(b&1) (ret+=a)%=MOD; return ret; } ll inv(ll a,ll MOD) {return quickpow(a,MOD-2,MOD);} void FFT(ll *a,int flag,ll MOD) { for(int i=0;i<lmt;i++) if(i<par[i]) swap(a[i],a[par[i]]); for(int len=1;len<lmt;len<<=1) { ll unit=quickpow(3,(MOD-1)/(len<<1),MOD); if(flag==-1) unit=inv(unit,MOD); for(int st=0;st<lmt;st+=(len<<1)) { ll w=1; for(int k=st;k<st+len;k++,w=w*unit%MOD) { ll A=a[k],B=w*a[k+len]%MOD; a[k]=(A+B)%MOD;a[k+len]=(A-B+MOD)%MOD; } } } if(flag==-1) { ll INV=inv(lmt,MOD); for(int i=0;i<lmt;i++) a[i]=a[i]*INV%MOD; } } void solve(ll *a,ll *b,ll MOD) { for(int i=0;i<=n;i++) a[i]=F[i]; for(int i=0;i<=m;i++) b[i]=G[i]; for(int i=m+1;i<lmt;i++) b[i]=0; FFT(a,1,MOD);FFT(b,1,MOD); for(int i=0;i<lmt;i++) a[i]=a[i]*b[i]%MOD; FFT(a,-1,MOD); } int main() { scanf("%d%d%d",&n,&m,&MOD); for(int i=0;i<=n;i++) scanf("%d",&F[i]); for(int i=0;i<=m;i++) scanf("%d",&G[i]); while(lmt<=n+m) lmt<<=1,dgt++; for(int i=0;i<lmt;i++) par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1)); for(int i=0;i<3;i++) solve(a[i],b,p[i]); for(int i=0;i<=n+m;i++) { ll M=p[0]*p[1]; ll A=(mul(a[0][i]*p[1],inv(p[1],p[0]),M)+ mul(a[1][i]*p[0],inv(p[0],p[1]),M))%M; ll K=mul(a[2][i]-A,inv(M,p[2]),p[2]); res[i]=(mul(K,M,MOD)+A%MOD)%MOD; } for(int i=0;i<=n+m;i++) printf("%lld ",res[i]); return 0; }
#include <bits/stdc++.h> using namespace std; #define X first #define Y second #define pb push_back typedef double db; typedef long long ll; typedef pair<int,int> P; const int MAXN=1e6+10; struct Complex { db x,y; Complex(db a=0,db b=0){x=a;y=b;} Complex operator +(const Complex& rhs) {return Complex(x+rhs.x,y+rhs.y);} Complex operator -(const Complex& rhs) {return Complex(x-rhs.x,y-rhs.y);} Complex operator *(const Complex& rhs) {return Complex(x*rhs.x-y*rhs.y,x*rhs.y+y*rhs.x);} }a[MAXN],b[MAXN],w[MAXN],t1[MAXN],t2[MAXN],t3[MAXN]; int n,m,MOD,lmt=1,dgt,par[MAXN];ll x,res[MAXN]; void FFT(Complex *a,int flag) { for(int i=0;i<lmt;i++) if(i<par[i]) swap(a[i],a[par[i]]); for(int len=1;len<lmt;len<<=1) for(int st=0;st<lmt;st+=(len<<1)) { int cur=0; for(int k=st;k<st+len;k++) { Complex A=a[k],B=w[cur]*a[k+len]; a[k]=A+B;a[k+len]=A-B; //预处理的写法 cur=(cur+flag*lmt/(len<<1)+lmt)&(lmt-1); } } if(flag==-1) for(int i=0;i<lmt;i++) a[i].x=floor(a[i].x/lmt+0.5); } void solve() { FFT(a,1);FFT(b,1); for(int i=0;i<lmt;i++) { Complex d1,d2,d3,d4; int j=(lmt-i)&(lmt-1); d1=(a[i]+Complex(a[j].x,-a[j].y))*Complex(0.5,0); d2=(a[i]-Complex(a[j].x,-a[j].y))*Complex(0,-0.5); d3=(b[i]+Complex(b[j].x,-b[j].y))*Complex(0.5,0); d4=(b[i]-Complex(b[j].x,-b[j].y))*Complex(0,-0.5); //必须先用临时变量存,因为后面还要用 t1[i]=d1*d3;t2[i]=d1*d4+d2*d3;t3[i]=d2*d4; } for(int i=0;i<lmt;i++) //充分利用虚部空间(可看成逆过程) b[i]=t2[i],a[i]=t1[i]+t3[i]*Complex(0,1); FFT(a,-1);FFT(b,-1); for(int i=0;i<lmt;i++) { ll k1=(ll)a[i].x%MOD,k2=(ll)b[i].x%MOD; ll k3=(ll)floor(a[i].y/lmt+0.5)%MOD; res[i]=((k3<<30)%MOD+(k2<<15)%MOD+k1)%MOD; } } int main() { scanf("%d%d%d",&n,&m,&MOD); for(int i=0;i<=n;i++) scanf("%lld",&x),a[i]=Complex(x&32767,x>>15); for(int i=0;i<=m;i++) scanf("%lld",&x),b[i]=Complex(x&32767,x>>15); while(lmt<=n+m) lmt<<=1,dgt++; for(int i=0;i<lmt;i++) par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1)); for(int i=0;i<lmt;i++) w[i]=Complex(cos(2*M_PI*i/lmt),sin(2*M_PI*i/lmt)); solve(); for(int i=0;i<=n+m;i++) printf("%lld ",res[i]); return 0; }
实测$myy$论文里拆系数优化至4次$DFT/IDFT$的方法比三模数$CRT$快7倍左右
7、SA
#include <bits/stdc++.h> using namespace std; #define X first #define Y second #define pb push_back #define debug(x) cerr<<#x<<"="<<x<<endl typedef double db; typedef long long ll; typedef pair<int,int> P; const int MAXN=1e6+10; char s[MAXN]; int len,lmt,cnt[MAXN],sa[MAXN],x[MAXN],y[MAXN],cur; void solve() { for(int i=1;i<=len;i++) cnt[x[i]=s[i]]++; for(int i=1;i<=lmt;i++) cnt[i]+=cnt[i-1]; for(int i=len;i>=1;i--) sa[cnt[x[i]]--]=i; for(int k=1;k<=len;k<<=1,lmt=cur) { cur=0; for(int i=len-k+1;i<=len;i++) y[++cur]=i; for(int i=1;i<=len;i++) if(sa[i]>k) y[++cur]=sa[i]-k; for(int i=1;i<=lmt;i++) cnt[i]=0; for(int i=1;i<=len;i++) cnt[x[i]]++; for(int i=1;i<=lmt;i++) cnt[i]+=cnt[i-1]; //一定要按第二关键字从大往小枚举! for(int i=len;i>=1;i--) sa[cnt[x[y[i]]]--]=y[i]; swap(x,y);cur=1;x[sa[1]]=1; for(int i=2;i<=len;i++) x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])?cur:++cur; if(cur==len) break; } } int main() { scanf("%s",s+1); len=strlen(s+1);lmt=130; solve(); for(int i=1;i<=len;i++) printf("%d ",sa[i]); return 0; }