题解:CF1303G Sum of Prefix Sums

CF1303G Sum of Prefix Sums

Solution

点分治,考虑将所求路径拆成两条,长度分别为 \(len,len'\),点权分别为 \(a,b\) 序列。

题目要求的是前缀和之和。记 \(s_i\) 为路径上第 \(i\) 个点的点权,\(n\) 为路径上的点数,则:

\[\begin{aligned} S &= s_1 + (s_1 + s_2) + \dots + (s_1 + s_2 + \dots + s_n) \\ &= \sum\limits_{i=1}^{n}{(n - i + 1) \times s_i} \\ &= (n + 1) \sum\limits_{i=1}^{n}{s_i} - \sum\limits_{i=1}^{n}{s_i \times i} \end{aligned} \]

记 \(sum_A = \sum\limits_{i=1}^{len}{a_i},sum_B = \sum\limits_{i=1}^{len}{a_i \times i},sum_A' = \sum\limits_{i=1}^{len'}{b_i},sum_B'= \sum\limits_{i=1}^{len'}{b_i \times i}\),则所求路径权值根据上文所推式子:

\[\begin{gathered} (len + len' + 1) \times (sum_A + sum_A') - (sum_B + sum_B') - sum_A' \times len \\ \Downarrow \\ len' \times sum_A + (len + 1) \times sum_A - sum_B \end{gathered} \]

你会发现我在这里把只跟 \(len',sum_A',sum_B'\) 有关的踢掉了,因为点分治之后肯定是会遍历到的所以不用管。

然后你就能轻松看出这是一个关于 \(len'\) 的一次函数(斜率为 \(sum_A\),截距为 \((len + 1) \times sum_A - sum_B\)),所以我们只需要把这条直线放到李超树上维护,查询时取 \(x = len'\) 处的最大值即可。但是你要注意,\(u \to v\) 的路径权值可能和 \(v \to u\) 的路径权值不同,所以正反都跑一遍。

Code

#include <bits/stdc++.h>
#define int long long
using namespace std;
namespace FASTIO {
	// Reads the next decimal integer from stdin.
	// Every non-digit character before the number is skipped; the result is negated
	// when the character immediately preceding the first digit is '-'.
	// Returns 0 if end of input is reached before any digit is seen.
	inline int read() {
		int res = 0, f = 1;
		// getchar() yields an int: keeping it wide lets EOF be told apart from real
		// characters and keeps the isdigit() argument in its defined range.
		int ch = getchar();
		// Stop on EOF as well, otherwise an exhausted stream would spin here forever.
		while (ch != EOF && !isdigit(ch)) f = ch == '-' ? -1 : 1, ch = getchar();
		while (ch != EOF && isdigit(ch)) res = res * 10 + (ch - '0'), ch = getchar();
		return res * f;
	}
	// Writes x in decimal to stdout, followed by a newline.
	inline void write (int x) {
		int st[33], top = 0;
		if (x < 0) putchar ('-');
		// Take the absolute value digit by digit instead of negating x up front,
		// so the most negative representable value does not overflow.
		do {
			int digit = x % 10;
			st[top ++] = digit < 0 ? -digit : digit;
			x /= 10;
		} while (x);
		while (top) putchar (st[-- top] + '0');
		putchar ('\n');
	}
} using namespace FASTIO;
// Capacity used to size the global arrays below (1.5e5 plus a margin of 10).
constexpr int MAXN = 1.5e5 + 10;
// n: number of vertices, read in main.
// Val[i]: value of vertex i, read in main for i = 1..n.
// Ans: running maximum that the divide-and-conquer updates and main prints.
int n = 0, Val[MAXN], Ans = 0;
// Adjacency lists; Addedge appends each endpoint to the other's list.
vector <int> edge[MAXN];
inline void Addedge (int u, int v) {
	edge[u].push_back (v);
	edge[v].push_back (u);
}

namespace LiChaoTree {
	// A line y = k * x + b. The default-constructed line (k = 0, b = -1e18)
	// serves as the "no line stored" marker throughout this namespace.
	struct Line {
		int k, b;
		Line (int k = 0, int b = -1e18) : k(k), b(b) {}
		// Value of the line at abscissa x.
		int getL (int x) { return k * x + b; }
	};
	// One segment-tree node: the interval [l, r], the line kept at this node,
	// and a flag recording whether Update has ever stored a line here.
	struct SMT {
		int l, r;
		Line ln;
		bool isCover;
		SMT (int l, int r, Line ln, bool isCover) : l(l), r(r), ln(ln), isCover(isCover) {}
		SMT () {}
	} Seg[MAXN << 2];

	// Initialises node rt for the interval [l, r] and, recursively, its subtree.
	// Children of node rt live at indices 2 * rt and 2 * rt + 1.
	void Build (int l, int r, int rt) {
		Seg[rt] = SMT (l, r, Line(), false);
		if (l != r) {
			const int mid = (l + r) >> 1;
			Build (l, mid, rt << 1);
			Build (mid + 1, r, rt << 1 | 1);
		}
	}

	// Inserts the line lnx into the subtree rooted at rt, keeping at each node the
	// line that is larger at the node's midpoint and pushing the other one down.
	void Update (int rt, Line lnx) {
		SMT &node = Seg[rt];
		const int lo = node.l, hi = node.r;
		if (lo == hi) {
			if (node.ln.getL(lo) < lnx.getL(lo))
				node.ln = lnx;
			return;
		}
		// Endpoint values are sampled before the node's line can be overwritten below.
		const int curLo = node.ln.getL(lo), curHi = node.ln.getL(hi);
		const int newLo = lnx.getL(lo), newHi = lnx.getL(hi);
		if (!node.isCover) {
			node.ln = lnx;
			node.isCover = true;
		}
		const bool winsLo = newLo >= curLo;
		const bool winsHi = newHi >= curHi;
		if (winsLo && winsHi) {
			// The new line is at least as high at both ends: it replaces the old one.
			node.ln = lnx;
			return;
		}
		if (!winsLo && !winsHi)
			return; // strictly lower at both ends: the new line is discarded
		// The lines cross inside [lo, hi]: keep the midpoint winner here and send
		// the loser to the half chosen by the comparison at the left endpoint.
		const int mid = (lo + hi) >> 1;
		if (node.ln.getL(mid) < lnx.getL(mid))
			swap (node.ln, lnx);
		const bool loserWinsLeft = node.ln.getL(lo) < lnx.getL(lo);
		Update (loserWinsLeft ? (rt << 1) : (rt << 1 | 1), lnx);
	}

	// Resets stored lines back to the default line. The walk stops at any node that
	// already holds the default line; isCover flags are left as they are.
	void ClearTree (int rt) {
		SMT &node = Seg[rt];
		const bool holdsDefault = node.ln.k == 0 && node.ln.b == -1e18;
		if (holdsDefault)
			return;
		node.ln = Line();
		if (node.l != node.r) {
			ClearTree (rt << 1);
			ClearTree (rt << 1 | 1);
		}
	}

	// Returns the maximum value at x = pos over the lines stored on the path from
	// node rt down to the leaf containing pos.
	int Query (int pos, int rt) {
		SMT &node = Seg[rt];
		const int here = node.ln.getL(pos);
		if (node.l == node.r)
			return here;
		const int mid = (node.l + node.r) >> 1;
		const int child = pos <= mid ? (rt << 1) : (rt << 1 | 1);
		return max (here, Query (pos, child));
	}
} using namespace LiChaoTree;

namespace DivTree {
	int Siz[MAXN], F[MAXN], rt = 0, SumA[MAXN], _SumA[MAXN], SumB[MAXN], _SumB[MAXN], totSiz = 0;
	bool tag[MAXN];
	void findrt (int u, int fa) {
		Siz[u] = 1, F[u] = 0;
		for (int v : edge[u]) {
			if (v == fa || tag[v])
				continue;
			findrt (v, u);
			Siz[u] += Siz[v];
			F[u] = max (F[u], Siz[v]);
		}
		F[u] = max (F[u], totSiz - Siz[u]);
		if (!rt || F[u] < F[rt]) rt = u;
	}
	void getSum (int u, int fa, int _len) {
		for (int v : edge[u]) {
			if (v == fa || tag[v])
				continue;
			SumA[v] = SumA[u] + Val[v];
			SumB[v] = SumB[u] + SumA[v];
			_SumA[v] = _SumA[u] + Val[v];
			_SumB[v] = _SumB[u] + _len * Val[v];
			getSum (v, u, _len + 1);
		}
	}
	void dfsQuery (int u, int fa, int _len) {
		Ans = max (Ans, SumB[u]);
		Ans = max (Ans, _SumB[u] + _SumA[u] + Val[rt]);
		Ans = max (Ans, Query (_len, 1) + _SumA[u] * (_len + 1) - _SumB[u]);
		for (int v : edge[u]) {
			if (v == fa || tag[v])
				continue;
			dfsQuery (v, u, _len + 1);
		}
	}
	void dfsUpdate (int u, int fa, int len) {
		Update (1, Line (SumA[u], SumA[u] * (len + 1) - SumB[u]));
		for (int v : edge[u]) {
			if (v == fa || tag[v])
				continue;
			dfsUpdate (v, u, len + 1);
		}
	}
	void Solve (int u) {
		rt = u, tag[u] = true;
		SumA[u] = SumB[u] = Val[u], _SumA[u] = _SumB[u] = 0;
		Ans = max (Ans, Val[u]), getSum (u, 0, 1), ClearTree(1);
		for (int v : edge[u]) {
			if (tag[v])
				continue;
			dfsQuery (v, u, 1), dfsUpdate (v, u, 2);
		}
		reverse (edge[u].begin(), edge[u].end()), ClearTree(1);
		for (int v : edge[u]) {
			if (tag[v])
				continue;
			dfsQuery (v, u, 1), dfsUpdate (v, u, 2);
		}
		for (int v : edge[u]) {
			if (tag[v])
				continue;
			rt = 0, totSiz = Siz[v];
			findrt (v, u), Solve(rt);
		}
	}
} using namespace DivTree;

// Entry point: reads the tree (n, then n - 1 edges, then the n vertex values),
// runs the divide-and-conquer and prints the resulting Ans.
signed main() {
	n = read();
	for (int remaining = n - 1; remaining > 0; -- remaining) {
		const int u = read();
		const int v = read();
		Addedge (u, v);
	}
	for (int i = 1; i <= n; ++ i) {
		Val[i] = read();
	}
	// Pick the first root over the whole tree before building the Li Chao tree on [1, n].
	rt = 0;
	totSiz = n;
	findrt (1, 0);
	Build (1, n, 1);
	Solve (rt);
	write (Ans);
	return 0;
}
posted @ 2025-02-21 11:23  xAlec  阅读(23)  评论(0)    收藏  举报