AtCoder Beginner Contest 397-f
考虑\(O({n^3})\)的做法。因为序列中的数不会改变,所以我们可以枚举第一个分割点\(i\)和第二个分割点\(j\),然后用\(O(n)\)的复杂度算出\(1\)到\(i\),\(i+1\)到\(j\),\(j+1\)到\(n\)三个范围中不同数个数之和,最终取最大值就是答案。
考虑优化,我们发现我们可以用一个前缀和数组和一个后缀和数组来预处理出原序列\(a\)的答案,而\(i+1\)到\(j\)这一区间无法通过预处理算出答案,我们只能每次再花费\(O(n)\)的复杂度来算出这一区间的答案(具体过程如下:1.枚举第一分割点\(i\) 2.枚举\(i+1\)到\(n\)这一区间算出这一区间的前缀和 3.枚举第二分割点\(j\)算出三个区间的答案(因为\(j+1\)到\(n\)这个区间的答案就是原序列的后缀和,不需要再次统计)复杂度为\(O(n)*O(n)=O({n^2})\))
最终优化,我们发现“枚举\(i+1\)到\(n\)这一区间算出这一区间的前缀和”这一步太浪费时间,所以我们要进一步优化这一步。因为我们要维护的是前缀和,所以每当在第一个数的前面插入一个数\(a_i\),\(i\)到\({x_i}-1\)(\(x_i\)代表着\(i\)后面第一个值与\(a_i\)相等的位置,没有则是\(n+1\))的前缀和都要加一,结合我们要查询的是区间的最大值,我们不难发现我们可以用线段树(区间修改,区间查询最大值)来优化算法,这样我们就把算法成功的优化为了\(O(nlog_n)\)的复杂度,完美地通过了\(3*{10^5}\)的数据。
CODE
#include<iostream>
#include<cstring>
using namespace std;
int n,a[300010],tj[300010],jt[300010],f[300010],x[300010];
struct node
{
int mxa,lazy;
}s[1200010];
void pushdown(int x,int l,int r)
{
if(s[x].lazy==0) return;
int mid=(l+r)>>1;
s[x<<1].mxa+=s[x].lazy;
s[x<<1].lazy+=s[x].lazy;
s[x<<1|1].mxa+=s[x].lazy;
s[x<<1|1].lazy+=s[x].lazy;
s[x].lazy=0;
}
int query(int u,int l,int r,int ll,int rr)
{
if(l>=ll && r<=rr)
{
return s[u].mxa;
}
pushdown(u,l,r);
int mid=(l+r)>>1;
if(rr<=mid) return query(u<<1,l,mid,ll,rr);
if(ll>mid) return query(u<<1|1,mid+1,r,ll,rr);
return max(query(u<<1,l,mid,ll,mid),query(u<<1|1,mid+1,r,mid+1,rr));
}
void add(int u,int l,int r,int ll,int rr,int z)
{
if(l>=ll && r<=rr)
{
s[u].mxa+=z;
s[u].lazy+=z;
return;
}
int mid=(l+r)>>1;
pushdown(u,l,r);
if(rr<=mid) add(u<<1,l,mid,ll,rr,z);
else if(ll>mid) add(u<<1|1,mid+1,r,ll,rr,z);
else
{
add(u<<1,l,mid,ll,mid,z);
add(u<<1|1,mid+1,r,mid+1,rr,z);
}
s[u].mxa=max(s[u<<1].mxa,s[u<<1|1].mxa);
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin>>n;
for(int i=1;i<=n;i++) cin>>a[i];
memset(f,0,sizeof(f));
for(int i=1;i<=n;i++)
{
if(!f[a[i]])
{
tj[i]=tj[i-1]+1;
f[a[i]]=i;
}
else tj[i]=tj[i-1];
}
memset(f,0,sizeof(f));
for(int i=n;i>=1;i--)
{
if(!f[a[i]])
{
jt[i]=jt[i+1]+1;
f[a[i]]=i;
}
else jt[i]=jt[i+1],x[i]=f[a[i]],f[a[i]]=i;
add(1,1,n,i,i,jt[i+1]);
}
int ans=0;
for(int i=n-2;i>=1;i--)
{
add(1,1,n,i+1,(x[i+1]==0?n:x[i+1]-1),1);
ans=max(ans,tj[i]+query(1,1,n,i+1,n-1));
}
cout<<ans;
return 0;
}