Codeforces 380E Sereja and Dividing 题解 [ 紫 ] [ 线段树 ] [ 贪心 ] [ 数学 ]
Sereja and Dividing:一年前的模拟赛就能秒这个 *2600 了,可我现在怎么还这么菜 /ll/ll/ll。
先考虑当杯子的集合固定时如何选择,显然一个杯子不会被选第二次,并且杯子从大到小选一定是最优的。证明只需要列出最后答案的式子:\(ans = \dfrac{a_1 + \dfrac{a_2 + \dfrac{a_3 + \dfrac{\cdots}{2}}{2}}{2}}{2}\),就能发现,\(a_1\) 对答案的系数是 \(\dfrac{1}{2}\),\(a_2\) 对答案的系数为 \(\dfrac{1}{4}\),\(a_i\) 对答案的系数为 \(\dfrac{1}{2^i}\)。因此 \(a_1, a_2, \cdots , a_n\) 一定是降序排列的。
接下来考虑如何计算它。注意到每个数的贡献与它在区间内的排名有关,如果一个数的排名为 \(i\),那么它的贡献为 \(\dfrac{1}{2^i}\),于是从小到大把元素加入,然后直接上线段树维护已加入元素的贡献。具体地,一开始线段树上的值全都是 \(1\)。当加入 \(a_x\) 时,令 \(x\) 处的值为 \(\dfrac{1}{2}\)。然后根据乘法原理可知答案为:
\[\sum_{i = 1}^{x}(\prod_{j = i}^{x}val_j)\times\sum_{i = x}^{n}(\prod_{j = x}^{i}val_j)\times a_x
\]
因此线段树上维护区间内前缀 / 后缀积的和即可。时间复杂度 \(O(n\log n)\)。
代码是从模意义下答案的版本魔改过来的,所以 ll
类型的意思其实是 long double
,凑合着看吧。
#include <bits/stdc++.h>
#define fi first
#define se second
#define lc (p<<1)
#define rc ((p<<1)|1)
using namespace std;
typedef long double ll;
typedef pair<int,int> pi;
const int N=300005;
struct node{
int l,r;
ll mul,lpsum,rpsum;
}tr[4*N];
void pushup(node &p,node ls,node rs)
{
p.mul=ls.mul*rs.mul;
p.lpsum=(ls.lpsum+ls.mul*rs.lpsum);
p.rpsum=(rs.rpsum+rs.mul*ls.rpsum);
}
void build(int p,int ln,int rn)
{
tr[p]={ln,rn,1,1,1};
if(ln==rn)return;
int mid=(ln+rn)>>1;
build(lc,ln,mid);
build(rc,mid+1,rn);
pushup(tr[p],tr[lc],tr[rc]);
}
void update(int p,int x)
{
if(tr[p].l==x&&tr[p].r==x)
{
tr[p].mul=tr[p].mul/2.0;
tr[p].lpsum=tr[p].mul;
tr[p].rpsum=tr[p].mul;
return;
}
int mid=(tr[p].l+tr[p].r)>>1;
if(x<=mid)update(lc,x);
else update(rc,x);
pushup(tr[p],tr[lc],tr[rc]);
}
node query(int p,int ln,int rn)
{
if(ln<=tr[p].l&&tr[p].r<=rn)
{
return tr[p];
}
int mid=(tr[p].l+tr[p].r)>>1;
if(rn<=mid)return query(lc,ln,rn);
if(ln>=mid+1)return query(rc,ln,rn);
node tmp;
pushup(tmp,query(lc,ln,rn),query(rc,ln,rn));
return tmp;
}
int n;
ll a[N],ans=0;
pi b[N];
bool cmp(pi x,pi y)
{
return x>y;
}
ll con(ll p)
{
node rnd=query(1,p,n);
ll rans=rnd.lpsum;
node lnd=query(1,1,p);
ll lans=lnd.rpsum;
return (lans*rans);
}
int main()
{
scanf("%d", &n);
for(int i=1;i<=n;i++)
{
scanf("%Lf", &a[i]);
b[i]={a[i],i};
}
sort(b+1,b+n+1,cmp);
build(1,1,n);
for(int i=1;i<=n;i++)
{
ans=(ans+((con(b[i].se) / 2.0)*b[i].fi));
update(1,b[i].se);
}
printf("%.15Lf", ans / (ll(n) * ll(n)));
return 0;
}