【CF1303G】Sum of Prefix Sums
题目
题目链接:https://codeforces.com/contest/1303/problem/G
有一颗 \(n\) 个节点的树,树每个节点有一个权值 \(a_i (1 \leq a_i \leq 10^6)\)。
定义树上 \(u \rightarrow v\) 的链的权值如下:将 \(u\) 到 \(v\) 的路径上点的权值依次排列在数组中,该数组的前缀和的和即这条路径的权值。
请求出权值最大的链,输出权值。
\(2 \leq n \leq 150000\)。
思路
考虑两条链 \(x\to y,y\to z\),把 \(y\) 看作根,\(s_1\) 为 \(x\to y\) 的权值,\(s_2\) 为 \(y\to z\) 的权值(均不包含 \(y\) 点的权值),那么 \(x\to z\) 的权值即为
因为这个东西我么记录一下三个前缀和就可以 \(O(1)\) 求出,所以可以考虑点分治,因为对于任意一条路径我们只需要枚举到其中一个点就可以了。
假设当前分到的根为 \(x\),记 \(sum[y][0/1/2]\) 表示 \(y\to x\) 的点的 \(a\) 之和,\(y\to x\) 的路径权值之和,\(x\to y\) 的路径权值之和。这个可以 dfs 一遍得到。
然后枚举 \(x\) 每一个儿子 \(y\),我们只需要在枚举过的儿子的子树中找到一个点 \(z\),使得
最大。
这个东西可以看作我们有一条斜率为 \(\text{dep}_z\) 的直线,需要在前面若干个 \((sum[z][0],sum[z][2])\) 中找到上凸壳的交点。这个直接上李超树就可以做到了。
注意需要正反枚举一遍,因为一条路径正着和反着的权值可能不同。
时间复杂度 \(O(n\log^2 n)\)。
未曾设想的错误 :

代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=300010,Inf=1e9;
int n,rt,tot,a[N],head[N],maxp[N],siz[N],dfn[N],rk[N];
ll ans,sum[N][3],dep[N];
bool vis[N];
struct edge
{
int next,to;
}e[N*2];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
ll Calc(int i,ll k)
{
return k*sum[i][0]+sum[i][2];
}
struct SegTree
{
int ans[N*4];
bool clr[N*4];
void pushdown(int x)
{
if (clr[x])
ans[x]=clr[x]=0,clr[x*2]=clr[x*2+1]=1;
}
void update(int x,int l,int r,int i)
{
pushdown(x);
if (!ans[x] || (Calc(ans[x],l)<=Calc(i,l) && Calc(ans[x],r)<=Calc(i,r)))
return (void)(ans[x]=i);
if (Calc(ans[x],l)>=Calc(i,l) && Calc(ans[x],r)>=Calc(i,r))
return;
int mid=(l+r)>>1;
if (Calc(ans[x],l)>=Calc(i,l))
{
if (Calc(ans[x],mid)<=Calc(i,mid))
update(x*2,l,mid,ans[x]),ans[x]=i;
else
update(x*2+1,mid+1,r,i);
}
else
{
if (Calc(ans[x],mid)<=Calc(i,mid))
update(x*2+1,mid+1,r,ans[x]),ans[x]=i;
else
update(x*2,l,mid,i);
}
}
ll query(int x,int l,int r,ll k)
{
pushdown(x);
if (l==r) return Calc(ans[x],k);
int mid=(l+r)>>1; ll res=Calc(ans[x],k);
if (k<=mid) return max(res,query(x*2,l,mid,k));
else return max(res,query(x*2+1,mid+1,r,k));
}
}seg;
void findrt(int x,int fa,int sum)
{
siz[x]=1; maxp[x]=0;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa && !vis[v])
{
findrt(v,x,sum);
siz[x]+=siz[v];
maxp[x]=max(maxp[x],siz[v]);
}
}
maxp[x]=max(maxp[x],sum-siz[x]);
if (maxp[x]<maxp[rt]) rt=x;
}
void dfs(int x,int fa)
{
dfn[x]=++tot; rk[tot]=x;
dep[x]=dep[fa]+1; siz[x]=1;
if (x!=rt)
{
sum[x][0]=sum[fa][0]+a[x];
sum[x][1]=a[x]+sum[fa][0]+sum[fa][1];
sum[x][2]=sum[fa][2]+(dep[x]-1)*a[x];
}
else sum[x][0]=sum[x][1]=sum[x][2]=0;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa && !vis[v])
{
dfs(v,x);
siz[x]+=siz[v];
}
}
}
void calc(int x)
{
tot=0; dfs(x,0);
stack<int> st;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to; st.push(v);
if (!vis[v])
{
for (int j=dfn[v];j<dfn[v]+siz[v];j++)
ans=max(ans,seg.query(1,1,n,dep[rk[j]])+sum[rk[j]][1]+dep[rk[j]]*a[x]);
for (int j=dfn[v];j<dfn[v]+siz[v];j++)
seg.update(1,1,n,rk[j]);
}
}
seg.clr[1]=1;
while (st.size())
{
int v=st.top(); st.pop();
if (!vis[v])
{
for (int j=dfn[v];j<dfn[v]+siz[v];j++)
{
ans=max(ans,seg.query(1,1,n,dep[rk[j]])+sum[rk[j]][1]+dep[rk[j]]*a[x]);
ans=max(ans,sum[rk[j]][2]+sum[rk[j]][0]+a[x]);
}
for (int j=dfn[v];j<dfn[v]+siz[v];j++)
seg.update(1,1,n,rk[j]);
}
}
seg.clr[1]=1;
}
void solve(int x)
{
calc(x); vis[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (!vis[v])
{
rt=0;
findrt(v,x,siz[v]);
solve(rt);
}
}
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d",&n);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
for (int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
ans=max(ans,1LL*a[i]);
}
maxp[0]=Inf;
findrt(1,0,n); solve(rt);
printf("%lld",ans);
return 0;
}

浙公网安备 33010602011771号