【CF1303G】Sum of Prefix Sums

题目

题目链接:https://codeforces.com/contest/1303/problem/G
有一颗 \(n\) 个节点的树,树每个节点有一个权值 \(a_i (1 \leq a_i \leq 10^6)\)
定义树上 \(u \rightarrow v\) 的链的权值如下:将 \(u\)\(v\) 的路径上点的权值依次排列在数组中,该数组的前缀和的和即这条路径的权值。
请求出权值最大的链,输出权值。
\(2 \leq n \leq 150000\)

思路

考虑两条链 \(x\to y,y\to z\),把 \(y\) 看作根,\(s_1\)\(x\to y\) 的权值,\(s_2\)\(y\to z\) 的权值(均不包含 \(y\) 点的权值),那么 \(x\to z\) 的权值即为

\[s_1+s_2+dep_{x}\sum_{p\in y\to z}a_p \]

因为这个东西我么记录一下三个前缀和就可以 \(O(1)\) 求出,所以可以考虑点分治,因为对于任意一条路径我们只需要枚举到其中一个点就可以了。
假设当前分到的根为 \(x\),记 \(sum[y][0/1/2]\) 表示 \(y\to x\) 的点的 \(a\) 之和,\(y\to x\) 的路径权值之和,\(x\to y\) 的路径权值之和。这个可以 dfs 一遍得到。
然后枚举 \(x\) 每一个儿子 \(y\),我们只需要在枚举过的儿子的子树中找到一个点 \(z\),使得

\[sum[y][1]+sum[z][2]+\text{dep}_z\cdot sum[z][0] \]

最大。
这个东西可以看作我们有一条斜率为 \(\text{dep}_z\) 的直线,需要在前面若干个 \((sum[z][0],sum[z][2])\) 中找到上凸壳的交点。这个直接上李超树就可以做到了。
注意需要正反枚举一遍,因为一条路径正着和反着的权值可能不同。
时间复杂度 \(O(n\log^2 n)\)

未曾设想的错误 :

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=300010,Inf=1e9;
int n,rt,tot,a[N],head[N],maxp[N],siz[N],dfn[N],rk[N];
ll ans,sum[N][3],dep[N];
bool vis[N];

struct edge
{
	int next,to;
}e[N*2];

void add(int from,int to)
{
	e[++tot]=(edge){head[from],to};
	head[from]=tot;
}

ll Calc(int i,ll k)
{
	return k*sum[i][0]+sum[i][2];
}

struct SegTree
{
	int ans[N*4];
	bool clr[N*4];
	
	void pushdown(int x)
	{
		if (clr[x])
			ans[x]=clr[x]=0,clr[x*2]=clr[x*2+1]=1;
	}
	
	void update(int x,int l,int r,int i)
	{
		pushdown(x);
		if (!ans[x] || (Calc(ans[x],l)<=Calc(i,l) && Calc(ans[x],r)<=Calc(i,r)))
			return (void)(ans[x]=i);
		if (Calc(ans[x],l)>=Calc(i,l) && Calc(ans[x],r)>=Calc(i,r))
			return;
		int mid=(l+r)>>1;
		if (Calc(ans[x],l)>=Calc(i,l))
		{
			if (Calc(ans[x],mid)<=Calc(i,mid))
				update(x*2,l,mid,ans[x]),ans[x]=i;
			else
				update(x*2+1,mid+1,r,i);
		}
		else
		{
			if (Calc(ans[x],mid)<=Calc(i,mid))
				update(x*2+1,mid+1,r,ans[x]),ans[x]=i;
			else
				update(x*2,l,mid,i);
		}
	}
	
	ll query(int x,int l,int r,ll k)
	{
		pushdown(x);
		if (l==r) return Calc(ans[x],k);
		int mid=(l+r)>>1; ll res=Calc(ans[x],k);
		if (k<=mid) return max(res,query(x*2,l,mid,k));
			else return max(res,query(x*2+1,mid+1,r,k));
	}
}seg;

void findrt(int x,int fa,int sum)
{
	siz[x]=1; maxp[x]=0;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa && !vis[v])
		{
			findrt(v,x,sum);
			siz[x]+=siz[v];
			maxp[x]=max(maxp[x],siz[v]);
		}
	}
	maxp[x]=max(maxp[x],sum-siz[x]);
	if (maxp[x]<maxp[rt]) rt=x;
}

void dfs(int x,int fa)
{
	dfn[x]=++tot; rk[tot]=x;
	dep[x]=dep[fa]+1; siz[x]=1;
	if (x!=rt)
	{
		sum[x][0]=sum[fa][0]+a[x];
		sum[x][1]=a[x]+sum[fa][0]+sum[fa][1];
		sum[x][2]=sum[fa][2]+(dep[x]-1)*a[x];
	}
	else sum[x][0]=sum[x][1]=sum[x][2]=0;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa && !vis[v])
		{
			dfs(v,x);
			siz[x]+=siz[v];
		}
	}
}

void calc(int x)
{
	tot=0; dfs(x,0);
	stack<int> st;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to; st.push(v);
		if (!vis[v])
		{
			for (int j=dfn[v];j<dfn[v]+siz[v];j++)
				ans=max(ans,seg.query(1,1,n,dep[rk[j]])+sum[rk[j]][1]+dep[rk[j]]*a[x]);
			for (int j=dfn[v];j<dfn[v]+siz[v];j++)
				seg.update(1,1,n,rk[j]);
		}
	}
	seg.clr[1]=1;
	while (st.size())
	{
		int v=st.top(); st.pop();
		if (!vis[v])
		{
			for (int j=dfn[v];j<dfn[v]+siz[v];j++)
			{
				ans=max(ans,seg.query(1,1,n,dep[rk[j]])+sum[rk[j]][1]+dep[rk[j]]*a[x]);
				ans=max(ans,sum[rk[j]][2]+sum[rk[j]][0]+a[x]);
			}
			for (int j=dfn[v];j<dfn[v]+siz[v];j++)
				seg.update(1,1,n,rk[j]);
		}
	}
	seg.clr[1]=1;
}

void solve(int x)
{
	calc(x); vis[x]=1;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (!vis[v])
		{
			rt=0;
			findrt(v,x,siz[v]);
			solve(rt);
		}
	}
}

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d",&n);
	for (int i=1,x,y;i<n;i++)
	{
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	for (int i=1;i<=n;i++)
	{
		scanf("%d",&a[i]);
		ans=max(ans,1LL*a[i]);
	}
	maxp[0]=Inf;
	findrt(1,0,n); solve(rt);
	printf("%lld",ans);
	return 0;
}
posted @ 2021-02-25 15:34  stoorz  阅读(67)  评论(0)    收藏  举报