LOJ#2537. 「PKUWC2018」Minimax 线段树合并
$O(n^2)$ 的式子是好列的,然后我们发现这是一个关于前后缀的转移.
用线段树合并优化这一过程.
具体地,分别维护 $x,y$ 的后缀和.
这里要注意:由于这道题中两个不同子树肯定没有交集,所以在线段树合并的时候肯定会合并到一个点,使得两个树中一个为空.
然后由于另一个是空的,就没有合并的必要了,这样整个区间乘的就是一个相同的数了.
这样就只需要维护一个乘法标记就行了.
code:
#include <bits/stdc++.h>
#define ll long long
#define mod 998244353
#define N 300008
#define lson s[x].ls
#define rson s[x].rs
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
int fa[N],cn,n,tot,ans;
int ch[N][2],val[N],perc[N],A[N],v[N];
int qpow(int x,int y)
{
int tmp=1;
for(;y;y>>=1,x=(ll)x*x%mod)
if(y&1) tmp=(ll)tmp*x%mod;
return tmp;
}
inline int INV(int x) { return qpow(x,mod-2); }
struct data
{
int ls,rs;
ll sum,tag;
// data(){ ls=rs=sum=0,tag=1; }
}s[N*50];
int rt[N];
inline int newnode() { return ++tot; }
inline void pushup(int x) { s[x].sum=(ll)(s[lson].sum+s[rson].sum)%mod; }
void update(int &x,int l,int r,int p,int v)
{
if(!x) x=newnode(),s[x].tag=1;
if(l==r) { s[x].sum=s[x].tag=v; return; }
int mid=(l+r)>>1;
if(p<=mid) update(lson,l,mid,p,v);
else update(rson,mid+1,r,p,v);
pushup(x);
}
inline void mark(int x,ll v)
{
s[x].tag=(ll)s[x].tag*v%mod;
s[x].sum=(ll)s[x].sum*v%mod;
}
inline void pushdown(int x)
{
if(s[x].tag!=1)
{
if(lson) mark(lson,s[x].tag);
if(rson) mark(rson,s[x].tag);
s[x].tag=1;
}
}
// s1-> 小的
// s2-> 多的
int merge(int x,int y,ll det,ll x1,ll x2,ll y1,ll y2)
{
if(!x&&!y) return 0;
if(!x)
{
ll up=(ll)((ll)(1-det+mod)%mod*x2%mod+(ll)det*x1%mod)%mod;
mark(y,up);
return y;
}
if(!y)
{
ll up=(ll)((ll)(1-det+mod)%mod*y2%mod+(ll)det*y1%mod)%mod;
mark(x,up);
return x;
}
int now=newnode();
pushdown(x),pushdown(y);
int xr=(ll)(x2+s[s[x].rs].sum)%mod;
int yr=(ll)(y2+s[s[y].rs].sum)%mod;
int xl=(ll)(x1+s[s[x].ls].sum)%mod;
int yl=(ll)(y1+s[s[y].ls].sum)%mod;
s[now].tag=1;
s[now].ls=merge(s[x].ls,s[y].ls,det,x1,xr,y1,yr);
s[now].rs=merge(s[x].rs,s[y].rs,det,xl,x2,yl,y2);
pushup(now);
return now;
}
void dfs(int x)
{
int l=ch[x][0],r=ch[x][1];
if(!l) update(rt[x],1,cn,v[x],1);
else if(!r) dfs(l),rt[x]=rt[l];
else dfs(l),dfs(r),rt[x]=merge(rt[l],rt[r],1ll*perc[x],0,0,0,0);
}
void output(int x,int l,int r)
{
if(!x) return;
if(l==r)
{
(ans+=(ll)l*A[l]%mod*s[x].sum%mod*s[x].sum%mod)%=mod;
return;
}
int mid=(l+r)>>1;
pushdown(x);
output(s[x].ls,l,mid);
output(s[x].rs,mid+1,r);
}
int main()
{
// setIO("input");
scanf("%d",&n);
for(int i=1;i<=n;++i)
{
scanf("%d",&fa[i]);
if(ch[fa[i]][0]) ch[fa[i]][1]=i;
else ch[fa[i]][0]=i;
}
for(int i=1;i<=n;++i)
{
int a;
scanf("%d",&a);
if(!ch[i][0]) val[i]=a,A[++cn]=val[i];
else perc[i]=(ll)a*INV(10000)%mod;
}
sort(A+1,A+1+cn);
for(int i=1;i<=n;++i) if(!ch[i][0]) v[i]=lower_bound(A+1,A+1+cn,val[i])-A;
dfs(1),output(rt[1],1,cn),printf("%d\n",ans);
return 0;
}

浙公网安备 33010602011771号