Codeforces 1500E - Subset Trick(线段树)
一道线段树的套路题(似乎 ycx 会做这道题?orzorz!!11)
首先考虑什么样的 \(x\) 是“不合适”的,我们假设 \(a_1<a_2<a_3<\cdots<a_{n-1}<a_n\),显然如果 \(\exist k\) 使得 \(a_1+a_2+\cdots+a_k\le x<a_{n-k+1}+a_{n-k+2}+\cdots+a_{n}\),那么 \(x\) 就是不合法的。也就是说所有 \(x\) 可取的数组成的集合为 \(n\) 个形如 \([a_1+a_2+\cdots+a_k,a_{n-k+1}+a_{n-k+2}+\cdots+a_{n})\) 的区间的并,而我们要求的就是这个并集的大小。
直接求并集似乎不是特别容易,因为可能会存在区间有交,贡献不太容易直接计算(当然如果硬要算也是可以算的,不过分类讨论应该会下面的解法稍微复杂些)。我们不妨考虑正难则反,拿总个数减去“合适”的 \(x\) 的个数,记 \(A=\sum\limits_{i=1}^na_i\),那么在区间 \([0,A)\) 中,\(x\) 为“合适”的值当且仅当 \(\exist k\in[0,n-1]\) 满足 \(a_{n-k+1}+a_{n-k+2}+\cdots+a_n\le x<a_1+a_2+\cdots+a_k+a_{k+1}\),而显然这样的 \(n\) 个区间是不存在交集的,因此“合适的” \(x\) 的个数可以直接通过将这 \(n\) 个区间的长度相加求出,即 \(\sum\limits_{k=0}^{n-1}\max((a_1+a_2+\cdots+a_k+a_{k+1})-(a_{n-k+1}+a_{n-k+2}+\cdots+a_n),0)\)(因为可能存在区间为空,因此要对 \(0\) 取 \(\max\))。我们记 \(f(k)=(a_1+a_2+\cdots+a_k+a_{k+1})-(a_{n-k+1}+a_{n-k+2}+\cdots+a_n)\),那么 \(ans=\sum\limits_{k=0}^{n-1}\max(f(k),0)\),这边带个 \(\max\),不好直接算出,因此我们不妨来探究下 \(f\) 函数有什么性质,很容易发现 \(f(k)\) 满足以下两个性质:
- \(f(k)=f(n-1-k)\)(提示:在左右两个括号同时加上 \(a_{k+2}+a_{k+3}+\cdots+a_{n-k}\))
- \(\forall x<y\le\dfrac{n-1}{2}\),\(f(x)>f(y)\)(提示:对于 \(k<\dfrac{n-1}{2}\),\(f(k)-f(k-1)=a_{k+1}-a_{n-k+1}<0\))
第一个性质告诉我们可以计算前一半(\(k\le\dfrac{n-1}{2}\))的答案,然后乘个 \(2\),如果 \(n\) 是奇数再减去 \(\max(f(\dfrac{n-1}{2}),0)\) 即可计算求得最终答案。第二个性质告诉我们满足 \(f(k)>0,k<\dfrac{n-1}{2}\) 的 \(k\) 是一段前缀,可二分找到最大的满足 \(f(l)>0\) 的 \(l\),计算 \(\sum\limits_{i=0}^lf(i)\),这样就可以不带 \(\max\),直接计算了。
最后是就是怎样计算 \(f(l)\) 和一段前缀的 \(f(l)\) 的问题。考虑将 \(a\) 离散化并建权值线段树,然后可在权值线段树上二分在对数时间内找到全局第 \(k\) 大的值以及前 \(k\) 大的和。而 \(\sum\limits_{i=0}^lf(i)=\sum\limits_{i=0}^l(\sum\limits_{j=1}^{i+1}a_j-\sum\limits_{j=n-i+1}^na_j)=\sum\limits_{j=1}^{l+1}a_j(l+1-j)-\sum\limits_{n-l+1}a_j(j-n+l)\),在线段树上额外维护 \(\sum\limits_{i}a_ii\) 即可,这东西也可在 \(\log n\) 时间内求出。时间复杂度 \(n\log^2n\),因为在找最大的 \(f(l)>0\) 的 \(l\) 过程中需二分,可能可以用线段树二分等技巧把那个 \(\log\) 去掉,不过没想了(
最后,有人可能会担心运算过程中爆 long long,事实上虽然 \(\sum\limits_{i}a_ii\) 可能会爆 long long,但是答案是在 long long 范围内的,因此可以理解为系统像字符串哈希一样自动对 \(2^{64}\) 取模,而最后运算结果 \(<2^{64}\),因此取模得到的结果就是它本身,故直接用 long long 存是不会出问题的。
const int MAXN=2e5;
int n,qu,cnt=0,num=0;
ll key[MAXN*2+5],uni[MAXN*2+5],a[MAXN+5],tot=0;
struct event{int opt;ll x;} q[MAXN+5];
struct node{int l,r,cnt;ll sum,sum_i;} s[MAXN*8+5];
void build(int k,int l,int r){
	s[k].l=l;s[k].r=r;if(l==r) return;
	int mid=l+r>>1;build(k<<1,l,mid);build(k<<1|1,mid+1,r);
}
void pushup(int k){
	s[k].cnt=s[k<<1].cnt+s[k<<1|1].cnt;s[k].sum=s[k<<1].sum+s[k<<1|1].sum;
	s[k].sum_i=s[k<<1].sum_i+s[k<<1|1].sum_i+s[k<<1|1].sum*s[k<<1].cnt;
}
void modify(int k,int p,int x){
	if(s[k].l==s[k].r){
		s[k].cnt+=x;s[k].sum+=x*uni[p];
		s[k].sum_i=1ll*s[k].cnt*(s[k].cnt+1)/2*uni[p];
		return;
	} int mid=s[k].l+s[k].r>>1;
	if(p<=mid) modify(k<<1,p,x);
	else modify(k<<1|1,p,x);
	pushup(k);
}
pair<ll,ll> getxth(int k,int x){
	if(s[k].l==s[k].r){return mp(1ll*x*uni[s[k].l],1ll*x*(x+1)/2*uni[s[k].l]);}
	if(x<=s[k<<1].cnt) return getxth(k<<1,x);
	else{
		pair<ll,ll> ss=getxth(k<<1|1,x-s[k<<1].cnt);
		return mp(ss.fi+s[k<<1].sum,ss.se+s[k<<1].sum_i+ss.fi*s[k<<1].cnt);
	}
}
ll getf(int k){
	pair<ll,ll> s1=getxth(1,k+1),s2=getxth(1,n-k);
	return s1.fi-(s[1].sum-s2.fi);
}
ll getans(){
	if(!n) return 0;
	int l=0,r=(n-1)/2,p=(n-1)/2+1;
	while(l<=r){
		int mid=l+r>>1;
		if(getf(mid)<=0) p=mid,r=mid-1;
		else l=mid+1;
	} p--;//printf("%d\n",p);
	pair<ll,ll> s1=getxth(1,p+1),s2=getxth(1,n-p);
	s2.fi=s[1].sum-s2.fi;s2.se=s[1].sum_i-s2.se;
//	printf("%lld %lld %lld %lld\n",s1.fi,s1.se,s2.fi,s2.se);
	ll sum=s1.fi*(p+2)-s1.se-s2.se+s2.fi*(n-p);
	sum=sum*2;if((n&1)) sum-=max(getf((n-1)/2),0ll);
	return tot-sum;
}
int main(){
	scanf("%d%d",&n,&qu);
	for(int i=1;i<=n;i++) scanf("%lld",&a[i]),key[++cnt]=a[i],tot+=a[i];
	for(int i=1;i<=qu;i++){scanf("%d%lld",&q[i].opt,&q[i].x);key[++cnt]=q[i].x;}
	sort(key+1,key+cnt+1);
	for(int i=1;i<=cnt;i++) if(key[i]^key[i-1]) uni[++num]=key[i];
	for(int i=1;i<=n;i++) a[i]=lower_bound(uni+1,uni+num+1,a[i])-uni;
	for(int i=1;i<=qu;i++) q[i].x=lower_bound(uni+1,uni+num+1,q[i].x)-uni;
	build(1,1,num);for(int i=1;i<=n;i++) modify(1,a[i],1);
	printf("%lld\n",getans());
	for(int i=1;i<=qu;i++){
		if(q[i].opt==1) modify(1,q[i].x,1),++n,tot+=uni[q[i].x];
		else modify(1,q[i].x,-1),--n,tot-=uni[q[i].x];
		printf("%lld\n",getans());
	}
	return 0;
}
 
                    
                
 
                
            
         浙公网安备 33010602011771号
浙公网安备 33010602011771号