2024.2.25模拟赛T3题解

题目

推出dp柿子之后,枚举 \(i\) 的时候用线段树维护 \(1-i\)\(mex\) 段,对于每一段,分别使用线段树套李超树维护,对于每个 \(mex\) 再次使用线段树套李超树维护即可

code

#include<bits/stdc++.h>
using namespace std;
#define N 600005
#define int long long
int n,m;
const int inf=1e18;
int a[N],s[N],f[N],c[N*4];
void upd(int p,int l,int r,int x,int y){
	if(l==r){
		c[p]=y;
		return ;
	}
	int mid=(l+r)>>1;
	if(x<=mid) upd(p*2,l,mid,x,y);
	else upd(p*2+1,mid+1,r,x,y);
	c[p]=min(c[p*2],c[p*2+1]);
}
int find(int p,int l,int r,int x){
	if(p==1) x--;
	if(x<l) return n;
	if(r<=x) return c[p];
	int mid=(l+r)>>1;
	if(x<=mid) return find(p*2,l,mid,x);
	else return min(c[p*2],find(p*2+1,mid+1,r,x));
}
int qry(int p,int l,int r,int x){
	if(l==r) return l;
	int mid=(l+r)>>1;
	if(c[p*2]<x) return qry(p*2,l,mid,x);
	else return qry(p*2+1,mid+1,r,x);
}
struct STS{
	int n,tot=0,top=0;
	int k[N],b[N]={-inf},c[N*20],ls[N*20],rs[N*20],rt[N*4];
	int cal(int x,int p){
		return k[p]*x+b[p];
	}
	void add(int &p,int l,int r,int x){
		if(!p) p=++top;
		if(l==r){
			if(cal(l,x)>cal(l,c[p])) c[p]=x;
			return ;
		}
		int mid=(l+r)>>1;
		if(cal(mid,x)>cal(mid,c[p])) swap(x,c[p]);
		if(cal(l,x)>cal(l,c[p])) add(ls[p],l,mid,x);
		if(cal(r,x)>cal(r,c[p])) add(rs[p],mid+1,r,x);
	}
	int find(int p,int l,int r,int x){
		if(!p) return -inf;
		if(l==r) return cal(x,c[p]);
		int mid=(l+r)>>1,val=cal(x,c[p]);
		if(x<=mid) val=max(val,find(ls[p],l,mid,x));
		else val=max(val,find(rs[p],mid+1,r,x));
		return val;
	}
	int qry(int p,int l,int r,int x,int y,int z){
		if(x>y) return -inf;
		if(x<=l&&r<=y) return find(rt[p],0,n,z);
		int mid=(l+r)>>1,val=-inf;
		if(x<=mid) val=max(val,qry(p*2,l,mid,x,y,z));
		if(mid<y) val=max(val,qry(p*2+1,mid+1,r,x,y,z));
		return val;
	}
	void upd(int p,int l,int r,int x,int y){
		add(rt[p],0,n,y);
		if(l==r) return ;
		int mid=(l+r)>>1;
		if(x<=mid) upd(p*2,l,mid,x,y);
		else upd(p*2+1,mid+1,r,x,y);
	}
	void ins(int kk,int bb,int x){
		tot++;k[tot]=kk,b[tot]=bb;
		upd(1,1,::n,x,tot);
	}
}T1,T2;
signed main(){
	scanf("%lld%lld",&n,&m);
	for(int i=1;i<=n;i++) scanf("%lld",&a[i]),s[i]=s[i-1]+a[i];
	T1.n=n,T2.n=s[n];
	for(int i=1;i<=n;i++){
		T1.ins(-s[i-1],f[i-1],i),T2.ins(0,f[i-1],i);
		int l=min(i,find(1,0,n,a[i]+1)),r=min(i,find(1,0,n,a[i]));
		upd(1,0,n,a[i],i);
		while(l<r){
			int col=qry(1,0,n,r);
			int nr=max(l,find(1,0,n,col+1));
		//	printf("i l r col nr: %lld %lld %lld %lld %lld\n",i,l,r,col,nr);
			int num=T1.qry(1,1,n,nr+1,r,col);
			T2.ins(col,num,nr+1);r=nr;
		}
		int dn=max(1ll,i-m+1);
		int col=qry(1,0,n,dn);
		r=min(i,find(1,0,n,col));
		f[i]=max(T2.qry(1,1,n,dn,i,s[i]),s[i]*col+T1.qry(1,1,n,dn,r,col));
	//	printf("%lld: %lld %lld %lld\n",i,f[i],T2.qry(1,1,n,dn,i,s[i]),s[i]*col+T1.qry(1,1,n,dn,r,col));
	}
	printf("%lld\n",f[n]);
	return 0;
}
posted @ 2024-02-26 21:30  hubingshan  阅读(53)  评论(0)    收藏  举报