day8T1改错记

题目描述

定义\(dist(i, j)\)为树上\(i, j\)两点的距离

给出节点编号\(1\)\(n\)的树,一个\(1\)\(n\)的排列\(a\)\(q\)次询问,每次给出\(k\),求\(\sum_{l = 1}^{k} \sum_{r = l}^{k} \sum_{i = l}^{r} \sum_{j = i}^{r} dist(a_i, a_j) \ mod \ 998244353\)

\(n,q \le 1e5\)

解析

设询问\(k\)答案为\(f[k]\)

由于询问都是从\(1\)开始,一个自然的想法便是从\(f[k - 1]\)推向\(f[k]\)

考虑新加入\(a_k\)后答案的增量\(g[k]\)

我们先把\(dist(a_i, a_j)\)拆成\(dep(a_i) + dep(a_j) - 2 \cdot dep(lca(a_i, a_j))\)

加入\(k\)位置后增加的区间是以\(k\)位置结尾的区间,对每个\(i < k\),区间\([i, k]\)会求\(k - i + 1\)次与\(a_k\)有关的\(lca\),所以\(g[k]\)包含\(k \cdot (k - 1) / 2 \cdot dep[a_k]\),同时\(a_i\)会和\(a_k\)\(i\)\(dist\),所以加上\(dep[a_i] \cdot i\),然后还要加上上一次的增量\(g[k - 1]\),因为每个上次增量计算过的区间这一次也会多计算一次

还剩下的就是\(lca\)的部分,因为\(a_i\)会和\(a_k\)\(i\)\(dist\),所以减去的就是\(2 \sum_{i} i \cdot dep(lca(a_i, a_k))\)

这个东西据说是套路树链剖分然而蒟蒻我见都没见过qwq,具体做法是每插入一个点\(a_i\),把这个点到根的路径上每个点点权加\(i\),然后你就发现\(a_k\)到根的路径上的点权和就神奇地变成了这个东西……

总的来讲就是

\[g[k] = g[k - 1] + \frac{k \cdot (k - 1) \cdot dep(a_k)}{2} + \sum_{i < k} i \cdot dep(a_i) \\ - 2 \sum_{i < k} i \cdot dep(lca(a_k, a_i)) \]

然后\(f[i] = f[i - 1] + g[i]\),顺次推一遍就行了

复杂度\(O(n \log^2 n)\),因为有个树链剖分

代码

PS.先是树剖的时候没有把size统计到父亲T飞,再是线段树询问没有push_down结果WA完……我好菜啊qwq

#include <cstdio>
#include <cstring>
#include <iostream>
#define MAXN 100005
#define REG register

typedef long long LL;
const LL mod = 998244353ll;
struct Edge {
	int v, next;
	Edge(int _v = 0, int _n = 0):v(_v), next(_n) {}
} edge[MAXN << 1];
int head[MAXN], fa[MAXN], dep[MAXN], top[MAXN], dfn[MAXN], idx, heavy[MAXN], size[MAXN];
int N, Q, f[MAXN], upd1, upd2;
struct SegmentTree {
	int sum[MAXN << 2], add[MAXN << 2];
	void push_up(int);
	void push_down(int, int, int);
	void update(int, int, int, int, int, int);
	int query(int, int, int, int, int);
} tr;

inline void add_edge(int u, int v) { static int cnt; edge[cnt] = Edge(v, head[u]); head[u] = cnt++; }
inline void insert(int u, int v) { add_edge(u, v); add_edge(v, u); }
void dfs(int);
void dfs2(int);
inline void inc(int &x, int y) { x += y; if (x >= mod) x -= mod; }
inline void dec(int &x, int y) { x -= y; if (x < 0) x += mod; }
inline int add(int x, int y) { int res = x + y; return res >= mod ? res - mod : res; }
inline int less(int x, int y) { int res = x - y; return res < 0 ? res + mod : res; }

int main() {
	freopen("sumsumsum.in", "r", stdin);
	freopen("sumsumsum.out", "w", stdout);
	memset(head, -1, sizeof head);
	scanf("%d%d", &N, &Q);
	for (int i = 1; i < N; ++i) {
		int u, v;
		scanf("%d%d", &u, &v);
		insert(u, v);
	}
	top[1] = dep[1] = 1;
	dfs(1);
	dfs2(1);
	for (int i = 1; i <= N; ++i) {
		int a; scanf("%d", &a);
		inc(upd2, add(upd1, (LL)i * (i - 1) / 2 * dep[a] % mod));
		inc(upd1,(LL)dep[a] * i % mod);
		int cur = a;
		while (cur) {
			int tp = top[cur];
			dec(upd2, tr.query(1, 1, N, dfn[tp], dfn[cur]) * 2 % mod);
			tr.update(1, 1, N, dfn[tp], dfn[cur], i);
			cur = fa[tp];
			//debug
			//printf("%d %d\n", upd1, upd2);
		}
		f[i] = add(f[i - 1], upd2);
	}
	while (Q--) {
		int k; scanf("%d", &k);
		printf("%d\n", f[k]);
	}
	return 0;
}
void dfs(int u) {
	dep[u] = dep[fa[u]] + 1;
	size[u] = 1;
	for (int i = head[u]; ~i; i = edge[i].next)
		if (edge[i].v ^ fa[u]) {
			fa[edge[i].v] = u;
			dfs(edge[i].v);
			size[u] += size[edge[i].v];
			if (!heavy[u] || size[edge[i].v] > size[heavy[u]]) heavy[u] = edge[i].v;
		}
}
void dfs2(int u) {
	dfn[u] = ++idx;
	if (heavy[u]) {
		top[heavy[u]] = top[u];
		dfs2(heavy[u]);
	}
	for (int i = head[u]; ~i; i = edge[i].next)
		if ((edge[i].v ^ fa[u]) && (edge[i].v ^ heavy[u])) { top[edge[i].v] = edge[i].v; dfs2(edge[i].v); }
}
void SegmentTree::push_down(int rt, int L, int R) {
	if (add[rt]) {
		int mid = (L + R) >> 1;
		(add[rt << 1] += add[rt]) %= mod;
		(add[rt << 1 | 1] += add[rt]) %= mod;
		sum[rt << 1] = (sum[rt << 1] + add[rt] * (LL)(mid - L + 1) % mod) % mod;
		sum[rt << 1 | 1] = (sum[rt << 1 | 1] + add[rt] * (LL)(R - mid) % mod) % mod;
		add[rt] = 0;
	}
}
inline void SegmentTree::push_up(int rt) {
	sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % mod;
}
void SegmentTree::update(int rt, int L, int R, int l, int r, int v) {
	if (L >= l && R <= r) {
		inc(add[rt], v);
		inc(sum[rt], v * (LL)(R - L + 1) % mod);
	} else {
		push_down(rt, L, R);
		int mid = (L + R) >> 1;
		if (l <= mid) update(rt << 1, L, mid, l, r, v);
		if (r > mid) update(rt << 1 | 1, mid + 1, R, l, r, v);
		push_up(rt);
	}
}
int SegmentTree::query(int rt, int L, int R, int l, int r) {
	if (L >= l && R <= r) return sum[rt];
	push_down(rt, L, R);
	int mid = (L + R) >> 1, res = 0;
	if (l <= mid) inc(res, query(rt << 1, L, mid, l, r));
	if (r > mid) inc(res, query(rt << 1 | 1, mid + 1, R, l, r));
	return res;
}
//Rhein_E
posted @ 2019-03-11 20:48  Rhein_E  阅读(150)  评论(0编辑  收藏  举报