day8T1改错记
题目描述
定义\(dist(i, j)\)为树上\(i, j\)两点的距离
给出节点编号\(1\)到\(n\)的树,一个\(1\)到\(n\)的排列\(a\),\(q\)次询问,每次给出\(k\),求\(\sum_{l = 1}^{k} \sum_{r = l}^{k} \sum_{i = l}^{r} \sum_{j = i}^{r} dist(a_i, a_j) \ mod \ 998244353\)
\(n,q \le 1e5\)
解析
设询问\(k\)答案为\(f[k]\)
由于询问都是从\(1\)开始,一个自然的想法便是从\(f[k - 1]\)推向\(f[k]\)
考虑新加入\(a_k\)后答案的增量\(g[k]\)
我们先把\(dist(a_i, a_j)\)拆成\(dep(a_i) + dep(a_j) - 2 \cdot dep(lca(a_i, a_j))\)
加入\(k\)位置后增加的区间是以\(k\)位置结尾的区间,对每个\(i < k\),区间\([i, k]\)会求\(k - i + 1\)次与\(a_k\)有关的\(lca\),所以\(g[k]\)包含\(k \cdot (k - 1) / 2 \cdot dep[a_k]\),同时\(a_i\)会和\(a_k\)求\(i\)次\(dist\),所以加上\(dep[a_i] \cdot i\),然后还要加上上一次的增量\(g[k - 1]\),因为每个上次增量计算过的区间这一次也会多计算一次
还剩下的就是\(lca\)的部分,因为\(a_i\)会和\(a_k\)求\(i\)次\(dist\),所以减去的就是\(2 \sum_{i} i \cdot dep(lca(a_i, a_k))\)
这个东西据说是套路树链剖分然而蒟蒻我见都没见过qwq,具体做法是每插入一个点\(a_i\),把这个点到根的路径上每个点点权加\(i\),然后你就发现\(a_k\)到根的路径上的点权和就神奇地变成了这个东西……
总的来讲就是
然后\(f[i] = f[i - 1] + g[i]\),顺次推一遍就行了
复杂度\(O(n \log^2 n)\),因为有个树链剖分
代码
PS.先是树剖的时候没有把size统计到父亲T飞,再是线段树询问没有push_down结果WA完……我好菜啊qwq
#include <cstdio>
#include <cstring>
#include <iostream>
#define MAXN 100005
#define REG register
typedef long long LL;
const LL mod = 998244353ll;
struct Edge {
int v, next;
Edge(int _v = 0, int _n = 0):v(_v), next(_n) {}
} edge[MAXN << 1];
int head[MAXN], fa[MAXN], dep[MAXN], top[MAXN], dfn[MAXN], idx, heavy[MAXN], size[MAXN];
int N, Q, f[MAXN], upd1, upd2;
struct SegmentTree {
int sum[MAXN << 2], add[MAXN << 2];
void push_up(int);
void push_down(int, int, int);
void update(int, int, int, int, int, int);
int query(int, int, int, int, int);
} tr;
inline void add_edge(int u, int v) { static int cnt; edge[cnt] = Edge(v, head[u]); head[u] = cnt++; }
inline void insert(int u, int v) { add_edge(u, v); add_edge(v, u); }
void dfs(int);
void dfs2(int);
inline void inc(int &x, int y) { x += y; if (x >= mod) x -= mod; }
inline void dec(int &x, int y) { x -= y; if (x < 0) x += mod; }
inline int add(int x, int y) { int res = x + y; return res >= mod ? res - mod : res; }
inline int less(int x, int y) { int res = x - y; return res < 0 ? res + mod : res; }
int main() {
freopen("sumsumsum.in", "r", stdin);
freopen("sumsumsum.out", "w", stdout);
memset(head, -1, sizeof head);
scanf("%d%d", &N, &Q);
for (int i = 1; i < N; ++i) {
int u, v;
scanf("%d%d", &u, &v);
insert(u, v);
}
top[1] = dep[1] = 1;
dfs(1);
dfs2(1);
for (int i = 1; i <= N; ++i) {
int a; scanf("%d", &a);
inc(upd2, add(upd1, (LL)i * (i - 1) / 2 * dep[a] % mod));
inc(upd1,(LL)dep[a] * i % mod);
int cur = a;
while (cur) {
int tp = top[cur];
dec(upd2, tr.query(1, 1, N, dfn[tp], dfn[cur]) * 2 % mod);
tr.update(1, 1, N, dfn[tp], dfn[cur], i);
cur = fa[tp];
//debug
//printf("%d %d\n", upd1, upd2);
}
f[i] = add(f[i - 1], upd2);
}
while (Q--) {
int k; scanf("%d", &k);
printf("%d\n", f[k]);
}
return 0;
}
void dfs(int u) {
dep[u] = dep[fa[u]] + 1;
size[u] = 1;
for (int i = head[u]; ~i; i = edge[i].next)
if (edge[i].v ^ fa[u]) {
fa[edge[i].v] = u;
dfs(edge[i].v);
size[u] += size[edge[i].v];
if (!heavy[u] || size[edge[i].v] > size[heavy[u]]) heavy[u] = edge[i].v;
}
}
void dfs2(int u) {
dfn[u] = ++idx;
if (heavy[u]) {
top[heavy[u]] = top[u];
dfs2(heavy[u]);
}
for (int i = head[u]; ~i; i = edge[i].next)
if ((edge[i].v ^ fa[u]) && (edge[i].v ^ heavy[u])) { top[edge[i].v] = edge[i].v; dfs2(edge[i].v); }
}
void SegmentTree::push_down(int rt, int L, int R) {
if (add[rt]) {
int mid = (L + R) >> 1;
(add[rt << 1] += add[rt]) %= mod;
(add[rt << 1 | 1] += add[rt]) %= mod;
sum[rt << 1] = (sum[rt << 1] + add[rt] * (LL)(mid - L + 1) % mod) % mod;
sum[rt << 1 | 1] = (sum[rt << 1 | 1] + add[rt] * (LL)(R - mid) % mod) % mod;
add[rt] = 0;
}
}
inline void SegmentTree::push_up(int rt) {
sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % mod;
}
void SegmentTree::update(int rt, int L, int R, int l, int r, int v) {
if (L >= l && R <= r) {
inc(add[rt], v);
inc(sum[rt], v * (LL)(R - L + 1) % mod);
} else {
push_down(rt, L, R);
int mid = (L + R) >> 1;
if (l <= mid) update(rt << 1, L, mid, l, r, v);
if (r > mid) update(rt << 1 | 1, mid + 1, R, l, r, v);
push_up(rt);
}
}
int SegmentTree::query(int rt, int L, int R, int l, int r) {
if (L >= l && R <= r) return sum[rt];
push_down(rt, L, R);
int mid = (L + R) >> 1, res = 0;
if (l <= mid) inc(res, query(rt << 1, L, mid, l, r));
if (r > mid) inc(res, query(rt << 1 | 1, mid + 1, R, l, r));
return res;
}
//Rhein_E