题解 Nasty Donchik 一道数据结构题

题目大意

题目链接 比赛链接

给定一个长度为\(n\)的序列\(a_1,a_2\dots,a_n\)。保证\(\forall i:1\leq a_i\leq n\)。请你求出,序列里有多少三元组\((i,j,k)\),满足\(a[i,j]\)里的所有数,都在\(a[j+1,k]\)里出现过;且\(a[j+1,k]\)里所有数,都在\(a[i,j]\)里出现过。

\(n\leq 2\times 10^5\)

本题题解

枚举\(k\)。对每个\(j\),维护使三元组\((i,j,k)\)合法的最小的和最大的\(i\),分别记为\(\text{mini}[j],\text{maxi}[j]\)。那么,当前\(k\)的三元组数量就是:\(\sum_{j=1}^{k-1}(\text{maxi}[j]-\text{mini}[j]+1)\)。考虑分别计算\(\text{maxi}\)的和和\(\text{mini}\)的和。

记每个位置\(t\)上的数上一次和下一次出现的位置分别为\(\text{pre}[t]\)\(\text{nxt}[t]\),特别地,如果前面/后面没有相同的数,则\(\text{pre}[t]=0\)\(\text{nxt}[t]=n+1\)。那么,我们发现,三元组\((i,j,k)\)合法的充分必要条件是:\(\max_{t=i}^{j}(\text{nxt}[t])\leq k\),且\(\min_{t=j+1}^{k}(\text{pre}[t])\geq i\)

由此可知,\(\text{maxi}[j]\)就是满足\(\min_{t=j+1}^{k}(\text{pre}[t])\geq i\)的最大的\(i\)\(\text{mini}[j]\)就是满足\(\max_{t=i}^{j}(\text{nxt[}t])\leq k\)的最小的\(i\)

\(\text{maxi}\)比较好维护,他就等于\(\min_{t=j+1}^{k}(\text{pre}[j])\)。当从\(k-1\)变到\(k\)时,我们让所有\(j\in[1,k-1]\)\(\text{maxi}[j]\)\(\text{pre}[k]\)\(\min\)即可。

考虑\(\text{mini}\)。我们称\(\text{nxt}[t]>k\)的位置为不合法的,其他位置为合法的。那么对于每个\(j\)\(\text{mini}[j]\)就相当于\(j\)前面、最靠近\(j\)的那个不合法的位置\(+1\)。特别地,如果\(j\)本身就不合法,我们认为\(\text{mini}[j]=j+1\)。从\(k-1\)变到\(k\),会使得所有\(\text{nxt}[t]=k\)的位置,从不合法变成合法。相当于把两段\(\text{mini}\)的区间“合并”起来(令后一段区间的值等于前一段区间的值)。而\(\text{nxt}[t]=k\)的位置最多只有一个:就是\(\text{pre}[k]\)。所以每次对一段区间执行区间覆盖(或者区间取\(\min\))即可(事实上因为\(\text{maxi}\)要支持的是区间取\(\min\),所以都用区间取\(\min\)反而更好写)。

还有一个要注意的点是,我们要始终保证,\(\text{mini}[j]\leq\text{maxi}[j]+1\),所以对\(\text{maxi}\)\(\min\)的时候,要对\(\text{mini}\)做一样的操作。

总结来说,需要支持区间对一个数取\(\min\),区间求和,可以用吉老师线段树实现。另外,我们还要对一个位置求它前面、最靠近它的不合法的位置,同时要支持单点修改(把某个位置从不合法变为合法),这个可以用线段上二分实现。

时间复杂度\(O(n\log n)\)

参考代码:

#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fst first
#define scd second

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

namespace Fread{
	const int MAXN=1<<20;
	char buf[MAXN],*S,*T;
	inline char getchar(){
		if(S==T){
			T=(S=buf)+fread(buf,1,MAXN,stdin);
			if(S==T)return EOF;
		}
		return *S++;
	}
}
#ifdef ONLINE_JUDGE
	#define getchar Fread::getchar
#endif
inline int read(){
	int f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline ll readll(){
	ll f=1,x=0;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
/*  ------  by:duyi  ------  */ // dysyn1314
const int MAXN=2e5;
int n;
/*
struct Baoli{
	int a[MAXN+5],val[MAXN+5],val2[MAXN+5];
	int get_nxt0(int p){
		for(int i=p;i<=n+1;++i)if(a[i]==0)return i;
		throw;
	}
	int get_pre0(int p){
		for(int i=p;i>=0;--i)if(a[i]==0)return i;
		throw;
	}
	void set1(int p){
		a[p]=1;
	}
	void init(){
		for(int i=1;i<=n;++i)val[i]=val2[i]=i;
	}
	void modify_min_mxi(int l,int r,int x){
		for(int i=l;i<=r;++i)val[i]=min(val[i],x);
	}
	void modify_min_mni(int l,int r,int x){
		for(int i=l;i<=r;++i)val2[i]=min(val2[i],x);
	}
	int get_sum_mxi(){
		int res=0;
		for(int i=1;i<=n;++i)res+=val[i]*a[i];
		return res;
	}
	int get_sum_mni(){
		int res=0;
		for(int i=1;i<=n;++i)res+=val2[i]*a[i];
		return res;
	}
}T;
*/
class SegmentTree{
private:
	int sz[MAXN*4+5],mx[2][MAXN*4+5],se[2][MAXN*4+5],ct[2][MAXN*4+5];
	ll sum[2][MAXN*4+5];
	void _pu(int p,int *mx,int *se,int *ct,ll *sum){
		sum[p]=sum[p<<1]+sum[p<<1|1];
		if(mx[p<<1]>mx[p<<1|1]){
			mx[p]=mx[p<<1];
			se[p]=max(se[p<<1],mx[p<<1|1]);
			ct[p]=ct[p<<1];
		}
		else if(mx[p<<1]<mx[p<<1|1]){
			mx[p]=mx[p<<1|1];
			se[p]=max(mx[p<<1],se[p<<1|1]);
			ct[p]=ct[p<<1|1];
		}
		else{
			mx[p]=mx[p<<1];
			se[p]=max(se[p<<1],se[p<<1|1]);
			ct[p]=ct[p<<1]+ct[p<<1|1];
		}
	}
	void push_up(int p){
		sz[p]=sz[p<<1]+sz[p<<1|1];
		_pu(p,mx[0],se[0],ct[0],sum[0]);
		_pu(p,mx[1],se[1],ct[1],sum[1]);
	}
	void _pd(int p,int *mx,int *ct,ll *sum){
		if(mx[p]<mx[p<<1]){
			sum[p<<1]-=(ll)ct[p<<1]*(mx[p<<1]-mx[p]);
			mx[p<<1]=mx[p];
		}
		if(mx[p]<mx[p<<1|1]){
			sum[p<<1|1]-=(ll)ct[p<<1|1]*(mx[p<<1|1]-mx[p]);
			mx[p<<1|1]=mx[p];
		}
	}
	void push_down(int p){
		_pd(p,mx[0],ct[0],sum[0]);
		_pd(p,mx[1],ct[1],sum[1]);
	}
	void build(int p,int l,int r){
		if(l==r){
			mx[0][p]=mx[1][p]=l;
			se[0][p]=se[1][p]=-1;
			return;
		}
		int mid=(l+r)>>1;
		build(p<<1,l,mid);
		build(p<<1|1,mid+1,r);
		push_up(p);
	}
	void modify1(int p,int l,int r,int pos){
		if(l==r){
			sz[p]=1;
			ct[0][p]=ct[1][p]=1;
			sum[0][p]=mx[0][p];
			sum[1][p]=mx[1][p];
			return;
		}
		push_down(p);
		int mid=(l+r)>>1;
		if(pos<=mid)modify1(p<<1,l,mid,pos);
		else modify1(p<<1|1,mid+1,r,pos);
		push_up(p);
	}
	int __first0(int p,int l,int r){
		if(l==r){assert(sz[p]==0);return l;}
		push_down(p);
		int mid=(l+r)>>1;
		if(sz[p<<1]<mid-l+1)return __first0(p<<1,l,mid);
		else return __first0(p<<1|1,mid+1,r);
	}
	int _nxt0(int p,int l,int r,int ql,int qr){
		if(ql<=l && qr>=r){
			if(sz[p]==r-l+1)return n+1;
			else return __first0(p,l,r);
		}
		push_down(p);
		int mid=(l+r)>>1,res=n+1;
		if(ql<=mid&&sz[p<<1]<mid-l+1)res=_nxt0(p<<1,l,mid,ql,qr);
		if(res!=n+1)return res;
		if(qr>mid&&sz[p<<1|1]<r-mid)return _nxt0(p<<1|1,mid+1,r,ql,qr);
		else return n+1;
	}
	int __last0(int p,int l,int r){
		if(l==r){assert(sz[p]==0);return l;}
		push_down(p);
		int mid=(l+r)>>1;
		if(sz[p<<1|1]<r-mid)return __last0(p<<1|1,mid+1,r);
		else return __last0(p<<1,l,mid);
	}
	int _pre0(int p,int l,int r,int ql,int qr){
		if(ql<=l && qr>=r){
			if(sz[p]==r-l+1)return 0;
			else return __last0(p,l,r);
		}
		push_down(p);
		int mid=(l+r)>>1,res=0;
		if(qr>mid&&sz[p<<1|1]<r-mid)res=_pre0(p<<1|1,mid+1,r,ql,qr);
		if(res)return res;
		if(ql<=mid&&sz[p<<1]<mid-l+1)return _pre0(p<<1,l,mid,ql,qr);
		else return 0;
	}
	void modify2(int p,int l,int r,int ql,int qr,int x,int t){
		//区间对x取min
		if(x>=mx[t][p])return;
		if(ql<=l && qr>=r && se[t][p]<x){
			sum[t][p]-=(ll)ct[t][p]*(mx[t][p]-x);
			mx[t][p]=x;
			return;
		}
		push_down(p);
		int mid=(l+r)>>1;
		if(ql<=mid)modify2(p<<1,l,mid,ql,qr,x,t);
		if(qr>mid)modify2(p<<1|1,mid+1,r,ql,qr,x,t);
		push_up(p);
	}
public:
	//mxi tree0
	//mni tree1
	void set1(int p){modify1(1,1,n,p);}
	int get_nxt0(int p){
		if(p>n)return n+1;
		if(p<1)return 0;
		return _nxt0(1,1,n,p,n);
	}
	int get_pre0(int p){
		if(p>n)return n+1;
		if(p<1)return 0;
		return _pre0(1,1,n,1,p);
	}
	void modify_min_mxi(int l,int r,int x){
		if(l>r)return;
		modify2(1,1,n,l,r,x,0);
	}
	void modify_min_mni(int l,int r,int x){
		if(l>r)return;
		modify2(1,1,n,l,r,x,1);
	}
	ll get_sum_mxi(){return sum[0][1];}
	ll get_sum_mni(){return sum[1][1];}
	void init(){build(1,1,n);}
}T;
int a[MAXN+5],nxt[MAXN+5],pre[MAXN+5],pos[MAXN+5];
int main(){
	n=read();
	for(int i=1;i<=n;++i){a[i]=read();pre[i]=pos[a[i]];pos[a[i]]=i;}
	for(int i=1;i<=n;++i)pos[i]=n+1;
	for(int i=n;i>=1;--i){nxt[i]=pos[a[i]];pos[a[i]]=i;}
	T.init();
	ll ans=0;
	for(int k=1;k<=n;++k){
		if(pre[k]){
			int x=T.get_nxt0(pre[k]+1)-1;
			//cout<<"* "<<x<<" "<<T.get_pre0(pre[k]-1)<<endl;
			T.modify_min_mni(pre[k],x,T.get_pre0(pre[k]-1));
			T.set1(pre[k]);
		}
		T.modify_min_mxi(1,k-1,pre[k]);
		T.modify_min_mni(1,k-1,pre[k]);
		ans+=T.get_sum_mxi()-T.get_sum_mni();
	}
	cout<<ans<<endl;
	return 0;
}
posted @ 2020-05-22 15:39  duyiblue  阅读(553)  评论(2编辑  收藏  举报