题解:CF1303G Sum of Prefix Sums
Solution
点分治,考虑将所求路径拆成两条,长度分别为 \(len,len'\),点权分别为 \(a,b\) 序列。
题目中要求前缀和的和,记 \(s_i\) 为路径上第 \(i\) 个点的点权,\(n\) 为路径上的点数,则:
\[\begin{aligned}
S &= s_1 + (s_1 + s_2) + \dots + (s_1 + s_2 + \dots + s_n)
\\
&= \sum\limits_{i=1}^{n}{(n - i + 1) \times s_i}
\\
&= (n + 1) \sum\limits_{i=1}^{n}{s_i} - \sum\limits_{i=1}^{n}{s_i \times i}
\end{aligned}
\]
记 \(sum_A = \sum\limits_{i=1}^{len}{a_i},sum_B = \sum\limits_{i=1}^{len}{a_i \times i},sum_A' = \sum\limits_{i=1}^{len'}{b_i},sum_B'= \sum\limits_{i=1}^{len'}{b_i \times i}\),则所求路径权值根据上文所推式子:
\[(len + len' + 1) \times (sum_A + sum_A') - (sum_B + sum_B') - sum_A' \times len
\\ \Downarrow \\
len' \times sum_A + (len + 1) \times sum_A - sum_B
\]
你会发现我在这里把只跟 \(len',sum_A',sum_B'\) 有关的踢掉了,因为点分治之后肯定是会遍历到的所以不用管。
然后你就能轻松看出这是一个关于 \(len'\) 的一次函数(斜率为 \(sum_A\),截距为 \((len + 1) \times sum_A - sum_B\)),所以我们只需要把式子放到李超树上维护即可。但是你要注意,\(u \to v\) 的路径权值可能和 \(v \to u\) 的路径权值不同,所以正反都跑一遍。
Code
#include <bits/stdc++.h>
#define int long long
using namespace std;
namespace FASTIO {
    // Reads the next decimal integer from stdin.
    //
    // Every non-digit character in front of the number is skipped; the result
    // is negated only when the character immediately before the first digit
    // is '-'. Returns 0 when end-of-input is reached before any digit is seen.
    inline int read() {
        int res = 0, f = 1;
        // Keep getchar()'s result in its native integer type: a `char` cannot
        // represent EOF and may be negative, which isdigit() does not accept.
        auto ch = std::getchar();
        // The EOF test is required: isdigit(EOF) is false, so without it this
        // loop would never terminate on exhausted input.
        while (ch != EOF && !std::isdigit(ch)) {
            f = (ch == '-') ? -1 : 1;
            ch = std::getchar();
        }
        // isdigit(EOF) is false, so this loop also stops at end-of-input.
        while (std::isdigit(ch)) {
            res = res * 10 + (ch - '0');
            ch = std::getchar();
        }
        return res * f;
    }

    // Writes x in decimal to stdout, followed by a newline.
    inline void write (int x) {
        // Work on the magnitude as an unsigned value so that the most negative
        // representable x does not overflow when its sign is flipped.
        unsigned long long magnitude = static_cast<unsigned long long>(x);
        if (x < 0) {
            magnitude = 0ULL - magnitude;
            std::putchar ('-');
        }
        char digits[20]; // a 64-bit unsigned value has at most 20 decimal digits
        int top = 0;
        do {
            digits[top ++] = static_cast<char>('0' + magnitude % 10);
            magnitude /= 10;
        } while (magnitude != 0);
        // Digits were produced least-significant first; emit them reversed.
        while (top) std::putchar (digits[-- top]);
        std::putchar ('\n');
    }
} using namespace FASTIO;
// Capacity used for every per-vertex array in this file (150010 after conversion).
constexpr int MAXN = 1.5e5 + 10;
// n: vertex count read in main(); Val[i]: value read for vertex i (filled for 1..n);
// Ans: largest path value found so far, printed at the end of main().
int n = 0, Val[MAXN], Ans = 0;
// Adjacency lists, indexed by vertex id; filled by Addedge().
vector <int> edge[MAXN];
// Records the undirected edge {u, v}: each endpoint is appended to the
// adjacency list of the other one.
inline void Addedge (int u, int v) {
    auto &neighboursOfU = edge[u];
    neighboursOfU.emplace_back (v);
    auto &neighboursOfV = edge[v];
    neighboursOfV.emplace_back (u);
}
// Li Chao segment tree over the integer positions fixed by Build(): stores
// lines y = k * x + b and answers "maximum y at position x".
namespace LiChaoTree {
    // Intercept of the placeholder line that means "nothing stored here".
    // Kept as an exact integer constant so that comparisons against it never
    // go through a lossy integer <-> floating-point conversion.
    constexpr int NEG_INF = -1000000000000000000LL;

    struct Line {
        int k, b;
        // The default-constructed line is the placeholder (k = 0, b = NEG_INF).
        Line (int k = 0, int b = NEG_INF) : k(k), b(b) {}
        // Value of the line at position x.
        int getL (int x) const { return k * x + b; }
        // True for the placeholder line.
        bool isEmpty () const { return k == 0 && b == NEG_INF; }
    };

    struct SMT {
        int l, r;      // inclusive range of positions covered by this node
        Line ln;       // line currently stored at this node
        bool isCover;  // true once an inner node has received a line since the last reset
        SMT (int l, int r, Line ln, bool isCover) : l(l), r(r), ln(ln), isCover(isCover) {}
        SMT () : l(0), r(0), ln(), isCover(false) {}
    } Seg[MAXN << 2];

    #define lson (rt << 1)
    #define rson (rt << 1 | 1)

    // Initialises the subtree rooted at node rt to cover positions [l, r],
    // with the placeholder line stored in every node.
    void Build (int l, int r, int rt) {
        Seg[rt] = SMT (l, r, Line(), false);
        if (l == r)
            return;
        const int mid = (l + r) >> 1;
        Build (l, mid, lson);
        Build (mid + 1, r, rson);
    }

    // Inserts line lnx into the subtree rooted at node rt.
    void Update (int rt, Line lnx) {
        SMT &node = Seg[rt];
        const int l = node.l, r = node.r;
        if (l == r) {
            // Single position: keep whichever line is higher there.
            if (node.ln.getL(l) < lnx.getL(l))
                node.ln = lnx;
            return;
        }
        if (!node.isCover) {
            // First line to reach this node since Build()/ClearTree().
            node.ln = lnx;
            node.isCover = true;
            return;
        }
        const bool winsLeft = lnx.getL(l) >= node.ln.getL(l);
        const bool winsRight = lnx.getL(r) >= node.ln.getL(r);
        if (winsLeft && winsRight) {
            // Not below the stored line at either end: replace it.
            node.ln = lnx;
            return;
        }
        if (!winsLeft && !winsRight)
            return; // below the stored line at both ends: discard it
        // The two lines cross inside [l, r]. Keep the one that is higher at
        // mid here and push the other one towards the side where it can
        // still be the higher line.
        const int mid = (l + r) >> 1;
        if (node.ln.getL(mid) < lnx.getL(mid))
            std::swap (node.ln, lnx);
        if (node.ln.getL(l) < lnx.getL(l)) {
            Update (lson, lnx);
        } else {
            Update (rson, lnx);
        }
    }

    // Resets every node of the subtree rooted at rt that holds a line.
    // A child only receives a line through Update() on a node that already
    // stores one, so the walk stops at the first placeholder node.
    void ClearTree (int rt) {
        SMT &node = Seg[rt];
        if (node.ln.isEmpty())
            return;
        node.ln = Line();
        node.isCover = false; // keep the flag consistent with the cleared line
        if (node.l == node.r)
            return;
        ClearTree (lson), ClearTree (rson);
    }

    // Returns the maximum value at position pos over the lines stored on the
    // path from node rt down to the leaf that covers pos. When no real line
    // is stored on that path this is the placeholder value NEG_INF.
    int Query (int pos, int rt) {
        int best = Seg[rt].ln.getL(pos);
        while (Seg[rt].l != Seg[rt].r) {
            const int mid = (Seg[rt].l + Seg[rt].r) >> 1;
            rt = (pos <= mid) ? lson : rson;
            best = std::max (best, Seg[rt].ln.getL(pos));
        }
        return best;
    }

    #undef lson
    #undef rson
} using namespace LiChaoTree;
// Divide and conquer on the tree: pick a root, evaluate every path that
// passes through it, remove it, and recurse into the remaining pieces.
namespace DivTree {
    // Siz[x]   : subtree size of x in the latest findrt() traversal that reached x
    // F[x]     : size of the largest piece left after removing x (set by findrt())
    // rt       : best root found by findrt() (0 = none yet); Solve() also stores
    //            the root being processed here, which dfsQuery() reads
    // SumA[x]  : sum of Val over the path root..x, root included
    // SumB[x]  : sum of the running prefix sums along the path root -> x
    // _SumA[x] : sum of Val over the path root..x, root excluded
    // _SumB[x] : sum of depth(y) * Val[y] over that path without the root,
    //            where the root's neighbours have depth 1
    // totSiz   : size of the piece that findrt() is currently searching
    int Siz[MAXN], F[MAXN], rt = 0, SumA[MAXN], _SumA[MAXN], SumB[MAXN], _SumB[MAXN], totSiz = 0;
    // tag[x] is true once x has been processed as a root and removed.
    bool tag[MAXN];

    // Traverses the piece containing u (entered from fa), filling Siz[] and
    // F[], and leaves in rt the vertex with the smallest F[]. The caller must
    // set rt = 0 and totSiz = size of the piece beforehand.
    void findrt (int u, int fa) {
        Siz[u] = 1;
        F[u] = 0;
        for (int v : edge[u]) {
            if (v == fa || tag[v])
                continue;
            findrt (v, u);
            Siz[u] += Siz[v];
            F[u] = std::max (F[u], Siz[v]);
        }
        // The part of the piece that lies "above" u also counts.
        F[u] = std::max (F[u], totSiz - Siz[u]);
        if (!rt || F[u] < F[rt])
            rt = u;
    }

    // Fills SumA/SumB/_SumA/_SumB for every vertex below u. _len is the depth
    // of u's children (the root's neighbours have depth 1).
    void getSum (int u, int fa, int _len) {
        for (int v : edge[u]) {
            if (v == fa || tag[v])
                continue;
            SumA[v] = SumA[u] + Val[v];
            SumB[v] = SumB[u] + SumA[v];
            _SumA[v] = _SumA[u] + Val[v];
            _SumB[v] = _SumB[u] + _len * Val[v];
            getSum (v, u, _len + 1);
        }
    }

    // Improves Ans with the paths that involve u, whose depth is _len:
    //   root -> u, u -> root, and x -> root -> u for every x whose line is
    //   currently stored in the Li Chao tree.
    // Relies on rt being the root that Solve() is processing.
    void dfsQuery (int u, int fa, int _len) {
        Ans = std::max (Ans, SumB[u]);
        Ans = std::max (Ans, _SumB[u] + _SumA[u] + Val[rt]);
        Ans = std::max (Ans, Query (_len, 1) + _SumA[u] * (_len + 1) - _SumB[u]);
        for (int v : edge[u]) {
            if (v == fa || tag[v])
                continue;
            dfsQuery (v, u, _len + 1);
        }
    }

    // Stores, for u and every vertex below it, the line with slope SumA and
    // intercept SumA * (len + 1) - SumB. len is the number of vertices on the
    // path root..u, root included.
    void dfsUpdate (int u, int fa, int len) {
        Update (1, Line (SumA[u], SumA[u] * (len + 1) - SumB[u]));
        for (int v : edge[u]) {
            if (v == fa || tag[v])
                continue;
            dfsUpdate (v, u, len + 1);
        }
    }

    // One pass over the untagged neighbours of u in the current order of
    // edge[u]: each subtree is first queried against the lines of the
    // subtrees handled before it, then its own lines are inserted.
    void sweepSubtrees (int u) {
        ClearTree (1);
        for (int v : edge[u]) {
            if (tag[v])
                continue;
            dfsQuery (v, u, 1);
            dfsUpdate (v, u, 2);
        }
    }

    // Processes u as a root, then recurses into each remaining piece.
    // On entry totSiz must hold the size of the piece that contains u, and
    // Siz[] must come from the findrt() traversal that selected u; main() and
    // the recursive call below both guarantee this.
    void Solve (int u) {
        const int pieceSize = totSiz;
        rt = u;
        tag[u] = true;
        SumA[u] = SumB[u] = Val[u];
        _SumA[u] = _SumB[u] = 0;
        Ans = std::max (Ans, Val[u]); // the path consisting of u alone
        getSum (u, 0, 1);

        // A path's value depends on its direction, so the sweep runs twice,
        // the second time with the neighbour order reversed.
        sweepSubtrees (u);
        std::reverse (edge[u].begin(), edge[u].end());
        sweepSubtrees (u);

        for (int v : edge[u]) {
            if (tag[v])
                continue;
            // Siz[v] is the true size of v's piece only when v was below u in
            // the traversal that selected u (then Siz[v] < Siz[u]). If v was
            // u's parent there, Siz[v] also counts u's side, and the piece is
            // really everything that was not in u's subtree.
            totSiz = (Siz[v] < Siz[u]) ? Siz[v] : pieceSize - Siz[u];
            rt = 0;
            findrt (v, u);
            Solve (rt);
        }
    }
} using namespace DivTree;
// Reads the vertex count, the n - 1 edges and the n vertex values, runs the
// divide and conquer, and prints the best path value.
signed main() {
    n = read();
    if (n < 1) {
        // Build(1, n, 1) on an empty range never reaches l == r and would
        // recurse without bound, so report the initial answer instead.
        write (Ans);
        return 0;
    }
    for (int i = 1; i < n; i ++) {
        const int u = read();
        const int v = read();
        Addedge (u, v);
    }
    for (int i = 1; i <= n; i ++)
        Val[i] = read();
    // Choose the first root over the whole tree. totSiz stays equal to n
    // until Solve() is entered.
    rt = 0;
    totSiz = n;
    findrt (1, 0);
    Build (1, n, 1);
    Solve (rt);
    write (Ans);
    return 0;
}

浙公网安备 33010602011771号