[NOI Online #2 提高组]子序列问题
description:
题目已经说得很清楚了
solution:
我们考虑记\(F_i=\sum_{k=1}^if(k,i)^2\)
考虑\(F_i\)如何由\(F_{i-1}\)递推过来
我们用\(pre_{a_i}\)表示\(a_i\)这个值上一次出现的位置(从未出现过则记为0)(下面简记为\(j\))
根据定义,
\(F_i=\sum_{k=1}^if(k,i)^2\)
\(F_{i-1}=\sum_{k=1}^{i-1}f(k,i-1)^2\)
同时对于\(\forall k\le j\)有\(f_{k,i}==f_{k,i-1}\)
对于\(\forall k> j\)有\(f_{k,i}==f_{k,i-1}+1\)其中\(f_{i,i-1}\)定义为\(0\)
因此\(F_i-F_{i-1}=\sum_{k=j+1}^i((f_{k,i-1}+1)^2-f_{k,i-1}^2)\)
稍微推下柿子可得\(原式=(i-j)+2\sum_{k=j+1}^{i-1}f_{k,i-1}\)
然后就可以发现这就是一个区间加&区间求和的东西,直接上线段树来维护就可以了
最后答案就是\(ans=\sum_{i=1}^nF_i\)
code:
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+5,mod=1e9+7;
int n,a[N],b[N],pre[N],f[N];
inline int read()
{
int s=0,w=1; char ch=getchar();
for(;!isdigit(ch);ch=getchar())if(ch=='-')w=-1;
for(;isdigit(ch);ch=getchar())s=(s<<1)+(s<<3)+(ch^48);
return s*w;
}
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
struct SGT
{
int tag[N<<2],s[N<<2];
#define lc rt<<1
#define rc rt<<1|1
inline void up(int rt){s[rt]=add(s[lc],s[rc]);}
inline void down(int rt,int l,int r)
{
int &tg=tag[rt];
if(!tg)return;
int mid=(l+r)>>1,lenl=mid-l+1,lenr=r-mid;
s[lc]=add(s[lc],1ll*lenl*tg%mod),s[rc]=add(s[rc],1ll*lenr*tg%mod);
tag[lc]=add(tag[lc],tg),tag[rc]=add(tag[rc],tg);
tg=0;
}
void upd(int rt,int ll,int rr,int l=1,int r=n)
{
if(ll<=l&&r<=rr){s[rt]=add(s[rt],r-l+1),++tag[rt];return;}
down(rt,l,r);
int mid=(l+r)>>1;
if(ll<=mid)upd(lc,ll,rr,l,mid);
if(mid<rr)upd(rc,ll,rr,mid+1,r);
up(rt);
}
int query(int rt,int ll,int rr,int l=1,int r=n)
{
if(ll<=l&&r<=rr)return s[rt];
down(rt,l,r);
int mid=(l+r)>>1,anss=0;
if(ll<=mid)anss=add(anss,query(lc,ll,rr,l,mid));
if(mid<rr)anss=add(anss,query(rc,ll,rr,mid+1,r));
up(rt);return anss;
}
#undef lc
#undef rc
}T;
int main()
{
n=read();
for(int i=1;i<=n;++i)a[i]=read(),b[i]=a[i];
sort(b+1,b+n+1);
int tot=unique(b+1,b+n+1)-b-1;
for(int i=1;i<=n;++i)
a[i]=lower_bound(b+1,b+tot+1,a[i])-b;//离散化
for(int i=1;i<=n;++i)
{
int j=pre[a[i]];
f[i]=add(f[i-1],i-j);
if(j+1<=i-1)f[i]=add(f[i],2*T.query(1,j+1,i-1)%mod);
T.upd(1,j+1,i);pre[a[i]]=i;
}
int ans=0;
for(int i=1;i<=n;++i)ans=add(ans,f[i]);
printf("%d\n",ans);
return 0;
}
NO PAIN NO GAIN