【BZOJ】3173: [Tjoi2013]最长上升子序列(树状数组)
【题意】给定ai,将1~n从小到大插入到第ai个数字之后,求每次插入后的LIS长度。
【算法】树状数组||平衡树
【题解】
这是树状数组的一个用法:O(n log n)寻找前缀和为k的最小位置。(当数列中只有0和1时,转化为求对应排名的数字,就是简单代替平衡树)
根据树状数组的二进制分组规律,从大到小进行倍增,可以发现每次需要加的Σa[i],i∈(now,now+(1<<i)]刚好就是c[now+(1<<i)]。
文字表述就是,跳跃到的位置的c[]刚好表示中间跳跃的数字和,这是树状数组二进制分组规律的特殊性质。
还需要注意的是,实际上需要寻找前缀和<k的最大位置,最后+1。(否则会被目标数字后面的0干扰)
利用上述的方法,初始树状数组全部置为1,然后从n到1倒着寻找并删除,就可以得到每个数字在最终序列中的位置。
这道题由于从小到大插入,可以发现将所有数字全部插入也不会破坏过程中需要的LIS(只会在最后增长)。
那么第i个答案就是以数字1~i结尾的LIS的最长长度。
所以令f[i]表示最终序列中以数字 i 结尾的LIS,则第i个答案就是min(f[j]),j=1~i。(是数字i,不是第i个位置)
求解f[i]只需在O(n log n)求解整个最终序列的LIS的过程中求出即可。
总复杂度O(n log n)。
最后,代码中运用的线性构造树状数组,原理十分简单。
首先要求1~n都有数字(0也行),然后每个数加到自身c[i]+=a[i],再贡献一下父亲c[i+lowbit(i)]+=c[i]就可以了。
#include<cstdio> #include<cstring> #include<algorithm> #include<cctype> #define lowbit(x) (x&-x) using namespace std; const int maxn=100010; int a[maxn],b[maxn],c[maxn],g[maxn],anss[maxn],n; int read(){ char c;int s=0,t=1; while(!isdigit(c=getchar()))if(c=='-')t=-1; do{s=s*10+c-'0';}while(isdigit(c=getchar())); return s*t; } void insert(int x,int k){for(int i=x;i<=n;i+=lowbit(i))c[i]+=k;} int find(int x){ int now=0,ans=0; for(int i=20;i>=0;i--){ now+=(1<<i); if(now<n&&ans+c[now]<x)ans+=c[now];//< near else now-=(1<<i); } now++; insert(now,-1); return now; } int max(int a,int b){return a<b?b:a;} int main(){ n=read(); for(int i=1;i<=n;i++){ a[i]=read(); c[i]++;c[i+lowbit(i)]+=c[i]; } for(int i=n;i>=1;i--)b[find(a[i]+1)]=i; int m=0; for(int i=1;i<=n;i++){ int s=lower_bound(g+1,g+m+1,b[i])-g; if(s>m)g[++m]=b[i];else g[s]=b[i]; anss[b[i]]=s; } for(int i=1;i<=n;i++){ anss[i]=max(anss[i-1],anss[i]); printf("%d\n",anss[i]); } return 0; }
补充平衡树写法(fhq-treap)。
每个点记录以这个点结尾的LIS,然后插入平衡树中,平衡树维护区间max值。
怎么得到以每个点结尾的LIS?因为当前加入的点不可能改变之前的点的LIS,所以只需要区间查询该点插入位置之前的max+1就是以这个点结尾的LIS。
#include<cstdio> #include<cstring> #include<algorithm> #include<cctype> using namespace std; const int maxn=100010; struct cyc{int l,r,rnd,num,mx,sz;}t[maxn]; int root,n; int read(){ char c;int s=0,t=1; while(!isdigit(c=getchar()))if(c=='-')t=-1; do{s=s*10+c-'0';}while(isdigit(c=getchar())); return s*t; } void up(int k){ t[k].sz=t[t[k].l].sz+t[t[k].r].sz+1; t[k].mx=max(t[k].num,max(t[t[k].l].mx,t[t[k].r].mx)); } void split(int k,int &l,int &r,int x){ if(!k)return void(l=r=0); if(x<t[t[k].l].sz+1){ r=k; split(t[k].l,l,t[k].l,x); } else{ l=k; split(t[k].r,t[k].r,r,x-t[t[k].l].sz-1); } up(k); } int merge(int a,int b){ if(!a||!b)return a^b; if(t[a].rnd<t[b].rnd){ t[a].r=merge(t[a].r,b); up(a); return a; } else{ t[b].l=merge(a,t[b].l); up(b); return b; } } void insert(int k,int x){ int a,b; split(root,a,b,x); t[k]=(cyc){0,0,rand(),t[a].mx+1,t[a].mx+1,1}; root=merge(a,k); root=merge(root,b); printf("%d\n",t[root].mx); } int main(){ n=read();root=0; for(int i=1;i<=n;i++)insert(i,read()); return 0; }