[bzoj4709][柠檬]

题目链接

思路

首先,最优秀的分法一定是每段两端都是这一段中最多的那个,否则可以把不是的那个踢出去单独成段肯定会更优秀。然后就成了将这个序列分段,保证每段两端元素相同的最大收益和。
用a[i]记录第i个位置上的数,用s[i]记录前i个元素中a[i]出现的次数。f[i]表示以前i个数的最大收益。
首先考虑\(n^2\)的dp。明显\(f[i]=max\{f[j]+a[i]*(s[i]-s[j]+1)^2\} (a[i]==a[j])\)

	for(int i=1;i<=n;++i)
		for(int j=1;j<=i;++j)
			if(a[i]==a[j])
				f[i]=max(f[i],f[j-1]+a[i]*(s[i]-s[j]+1)*(s[i]-s[j]+1));

然后可发现,在上面的式子中,s数组是单增的,f数组也是单增的。如果知道了两个位置x和y(x<y)。通过二分,可以找到一个now使得当以后的某个位置pos的s[pos]>now之后的所有位置用x转移会比y优秀,这时y就没用了。所以用一个单调栈维护即可。

\(O(n^2)\)代码

#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<cstdlib>
#include<map>
#include<queue>
using namespace std;
typedef long long ll;
const int N=100000+100;
ll read() {
	ll x=0,f=1; char c=getchar();
	while(c<'0'||c>'9') {
		if(c=='-') f=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9') {
		x=x*10+c-'0';
		c=getchar();
	}
	return x*f;
}
int n;
int a[N],s[N],c[N];
ll f[N];
int main() {
	n=read();
	for(int i=1;i<=n;++i) {
		a[i]=read();
		s[i]=++c[a[i]];
	}
	for(int i=1;i<=n;++i)
		for(int j=1;j<=i;++j)
			if(a[i]==a[j])
				f[i]=max(f[i],f[j-1]+a[i]*(s[i]-s[j]+1)*(s[i]-s[j]+1));
	cout<<f[n];
	return 0;
}

\(O(n)\)代码

#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<cstdlib>
#include<map>
#include<bits/stdc++.h>
#include<queue>
#include<vector>
using namespace std;
typedef long long ll;
const int N=100000+100;
ll read() {
	ll x=0,f=1; char c=getchar();
	while(c<'0'||c>'9') {
		if(c=='-') f=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9') {
		x=x*10+c-'0';
		c=getchar();
	}
	return x*f;
}
vector<int> sta[N];
int a[N],s[N],c[N];
ll f[N];
ll calc(int x,int y) {
	return f[x-1]+(ll)a[x]*y*y;
}
int n;
int find(int x,int y) {//寻找x比y优秀的最早时间 
	int l=1,r=n;
	int ans=n+1;
	while(l<=r) {
		int mid=l+r>>1;
		if(calc(x,mid-s[x]+1)>=calc(y,mid-s[y]+1)) ans=mid,r=mid-1;
		else l=mid+1;
	} 
	return ans;
}
int main() {
	n=read();
	for(int i=1;i<=n;++i) {
		a[i]=read();
		s[i]=++c[a[i]];
	}
	for(int i=1;i<=n;++i) {
		while(sta[a[i]].size()>=2&&find(sta[a[i]][sta[a[i]].size()-1],i)>=find(sta[a[i]][sta[a[i]].size()-2],sta[a[i]][sta[a[i]].size()-1]))
			sta[a[i]].pop_back();
		sta[a[i]].push_back(i);
		while(sta[a[i]].size()>=2&&find(sta[a[i]][sta[a[i]].size()-2],sta[a[i]][sta[a[i]].size()-1])<=s[i]) {
			sta[a[i]].pop_back();
		}
		int now=sta[a[i]].size();
		f[i]=calc(sta[a[i]][now-1],s[i]-s[sta[a[i]][now-1]]+1);
	}
	cout<<f[n];
}

posted @ 2018-10-08 10:40  wxyww  阅读(212)  评论(0编辑  收藏  举报