Loading

动态dp 学习笔记

前言

被迫营业*2,不过去仔细学一下也挺好的。为了营业去学了好多新东西(((

由于本人水平有限,如有不严谨的地方还请指出。

ddp 主要用来处理树上dp问题,有时候出题人比较恶心带上修改,ddp就是用来支持快速修改的。

使用的前提是转移比较简洁,可以写成矩阵,基本上都是单点修改。

我感觉这个算法还是看题比较好理解。

例题一

这里会介绍三种常见的ddp维护方法。

P4719 【模板】"动态 DP"&动态树分治

首先考虑不带修改的情况。

\(f(u, 0 / 1)\) 分别表示:强制选择 \(u\) 这个点,强制不选择 \(u\) 这个点 时,以 \(u\) 为根的子树的最大独立集。

这时候以 \(u\) 为根的子树的答案就是 \(\max(f(u,0),f(u,1))\)

有转移

\[f(u,0) = \sum_{v\in son(u)} \max(f(v, 0),f(v, 1))\\ f(u,1) = a_u + \sum_{v\in son(u)} f(v, 0) \]

接下去引入 ddp 的思想。

考虑将重儿子和轻儿子分开考虑。

\(g(u, 0 / 1)\) 表示,只考虑 \(u\) 的轻儿子,强制选 \(u\) 或不选 \(u\) 这个点时,以 \(u\) 为根的子树的最大独立集。

\(wson(u)\)\(u\) 的重儿子。

可以得到

\[g(u, 0) = \sum_{v\in son(u), v\not = wson(u)} \max(f(v,0),f(v,1))\\ g(u, 1) = a_u + \sum_{v\in son(u), v\not = wson(u)} f(v,0)\\ f(u, 0) = g(u, 0) + \max(f(son_u, 0), f(son_u, 1))\\ f(u, 1) = g(u, 1) + f(son_u, 0) \]

注意转移方程中,把 \(a_u\) 的贡献加到了 \(g(u,1)\) 中,不然第四条多个 \(a_u\) 方程不够简洁,不方便写成矩阵。

\(f\) 的转移写成矩阵:

\[\begin{bmatrix} g(u,0)&g(u,0)\\ g(u,1)&-\infty \end{bmatrix} * \begin{bmatrix} f(son_u,0)\\ f(son_u,1) \end{bmatrix} = \begin{bmatrix} f(u,0)\\ f(u,1) \end{bmatrix} \]

注意上面的 \(*\) 是广义矩阵乘法: \(a_{i,j}=\max\{b_{i,k}+c_{k,j}\}\),这个运算符也是有结合律的。

一开始的时候我们先一趟dp求出 \(f,g\)

这时候对于点 \(u\) 为根的子树查询答案会非常方便:设 \(u\) 所在重链底端是 \(End(u)\),我们把 \(u\)\(End(u)\) 这段区间的矩阵全部按顺序乘起来就好了。

可以脑补一下这个过程:重链底端是个叶子,然后不断加入重链周围的轻子树以及重儿子,拼凑成了整颗子树。

考虑如何带上修改。

假设修改了点 \(u\)

直接影响到的是 \(f(u,1),g(u,1)\)

接着可以想象,往祖先走的时候同重链的 \(f\) 都能被矩阵直接更新,那么影响到的就是所有轻边的转移。

于是考虑计算更新之后对于轻边父亲的 \(g\) 的贡献。

发现 \(g\) 的转移和 \(f\) 有关,并且我们可以知道这条重链的 \(f\) 以及更新前重链的 \(f\),那么把之前的贡献减掉,把现在的贡献加上,就更新完毕了。

快速查询两点间矩阵乘积可以通过 树链剖分+线段树维护区间矩阵乘积 来维护,而且修改也可以通过跳轻边很方便地维护。

总共会跳到 \(O(\log n)\) 条轻边,还有每跳一次线段树修改的 \(O(\log n)\),总复杂度是 \(O(n\log^2 n)\)

实现的时候建议这种小型矩阵手动展开,常数上可以减小好多。

远古代码。但是好像不是特别丑

Code
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef double db;
#define pb(x) push_back(x)
#define mkp(x,y) make_pair(x,y)
inline int read() {
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)) {if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch))x=x*10+(ch^48),ch=getchar();
	return x*f;
}
const int N=100005;
const int M=N<<2;
const int inf=1e8;
int n,m,a[N];
int siz[N],dfn[N],tmr,son[N],fa[N],top[N],rev[N],ed[N],f[N][2];
struct edge{
	int nxt,to;
}e[N<<1];
int head[N],num_edge;
void addedge(int fr,int to){
	++num_edge;
	e[num_edge].nxt=head[fr];
	e[num_edge].to=to;
	head[fr]=num_edge;
}
struct Matrix{
	int a[2][2];
	Matrix(){a[0][0]=a[0][1]=a[1][0]=a[1][1]=-inf;}
	int*operator[](const int&k){return a[k];}
	Matrix operator * (const Matrix&b){
		Matrix res;
//		for(int i=0;i<2;++i)
//			for(int j=0;j<2;++j)
//				for(int k=0;k<2;++k)
//					res.a[i][j]=max(res.a[i][j],a[i][k]+b.a[k][j]);
		res[0][0]=max(a[0][0]+b.a[0][0],a[0][1]+b.a[1][0]);
		res[0][1]=max(a[0][0]+b.a[0][1],a[0][1]+b.a[1][1]);
		res[1][0]=max(a[1][0]+b.a[0][0],a[1][1]+b.a[1][0]);
		res[1][1]=max(a[1][0]+b.a[0][1],a[1][1]+b.a[1][1]);
		return res;
	}
}mat[N],val[M];
void dfs1(int u,int ft){
	siz[u]=1,f[u][1]=a[u];
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;if(v==ft)continue;
		fa[v]=u,dfs1(v,u),siz[u]+=siz[v];
		if(siz[v]>siz[son[u]])son[u]=v;
		f[u][0]+=max(f[v][0],f[v][1]);
		f[u][1]+=f[v][0];
	}
}
void dfs2(int u,int tp){
	top[u]=tp,dfn[u]=++tmr,rev[tmr]=u;
	if(son[u])dfs2(son[u],tp),ed[u]=ed[son[u]];
	else ed[u]=u;
	int g[2];g[0]=0,g[1]=a[u];
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==son[u]||v==fa[u])continue;
		dfs2(v,v);
		g[0]+=max(f[v][0],f[v][1]);
		g[1]+=f[v][0];
	}
	mat[u][0][0]=g[0],mat[u][0][1]=g[0];
	mat[u][1][0]=g[1],mat[u][1][1]=-inf;
}
#define lc (p<<1)
#define rc (p<<1|1)
void pushup(int p){val[p]=val[lc]*val[rc];}
void build(int l,int r,int p){
	if(l==r)return val[p]=mat[rev[l]],void();
	int mid=(l+r)>>1;
	build(l,mid,lc),build(mid+1,r,rc);
	pushup(p);
}
Matrix query(int ql,int qr,int l=1,int r=n,int p=1){
	if(ql<=l&&r<=qr)return val[p];
	int mid=(l+r)>>1;
	if(qr<=mid)return query(ql,qr,l,mid,lc);
	if(mid<ql)return query(ql,qr,mid+1,r,rc);
	return query(ql,qr,l,mid,lc)*query(ql,qr,mid+1,r,rc);
}
void change(int pos,int l=1,int r=n,int p=1){
	if(l==r)return val[p]=mat[rev[l]],void();
	int mid=(l+r)>>1;
	if(pos<=mid)change(pos,l,mid,lc);
	else change(pos,mid+1,r,rc);
	pushup(p);
}
void update(int x,int v){
	mat[x][1][0]+=v-a[x],a[x]=v;
	while(x){
		Matrix lst=query(dfn[top[x]],dfn[ed[x]]);
		change(dfn[x]);
		Matrix now=query(dfn[top[x]],dfn[ed[x]]);
		x=fa[top[x]];
		mat[x][0][0]+=max(now[0][0],now[1][0])-max(lst[0][0],lst[1][0]);
		mat[x][0][1]=mat[x][0][0];
		mat[x][1][0]+=now[0][0]-lst[0][0];
	}
}
signed main(){
	n=read(),m=read();
	for(int i=1;i<=n;++i)a[i]=read();
	for(int i=1;i<n;++i){
		int x=read(),y=read();
		addedge(x,y),addedge(y,x);
	}
	dfs1(1,0),dfs2(1,1),build(1,n,1);
	while(m--){
		int x=read(),v=read();
		update(x,v);
		Matrix t=query(dfn[1],dfn[ed[1]]);
		printf("%d\n",max(t[0][0],t[1][0]));
	}
	return 0;
}

树剖+线段树的复杂度是两只 log,这使得人们思考有没有更快的方法。

P4751 【模板】"动态DP"&动态树分治(加强版)

可以发现上面那种方法的在做的其实就是维护链上矩阵积,维护链上信息使我们想到了 \(O(n\log n)\) 的 LCT。

考虑只维护实边信息,虚儿子信息在 access 的时候更新上去。

更新一个节点信息的时候可以先 access 再 splay,这时候修改它对于任何节点都是没有影响的,可以直接修改。

查询一个节点的信息会有点特殊,需要执行的操作是:\(access(fa_x),splay(x)\)。因为把 \(fa_x\) 以上的节点移到别的 splay 里面去,\(splay(x)\)\(x\) 下面挂的节点才是 \(x\) 的子树内的节点。

这里附上一份 LCT 实现,顺便去学了一下。

成功把复杂度降掉一只 \(\log\),LCT 的常数非常大就是了。

说句闲话,这题貌似正常常数的 LCT 都能过去(

Code
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp make_pair
#define pb push_back
#define sz(v) (int)(v).size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
	return f?x:-x;
}
const int N = 1000005;
const int inf = 0x3f3f3f3f;
int n, m, a[N], lastans;
vector<int> e[N];
struct Matrix {
	int a[2][2];
	Matrix(){ memset(a, -0x3f, sizeof a); }
	inline int* operator [](const int &k) { return a[k]; }
	inline Matrix operator * (const Matrix &t) const {
		Matrix res;
		res.a[0][0] = max(a[0][0] + t.a[0][0], a[0][1] + t.a[1][0]);
		res.a[0][1] = max(a[0][0] + t.a[0][1], a[0][1] + t.a[1][1]);
		res.a[1][0] = max(a[1][0] + t.a[0][0], a[1][1] + t.a[1][0]);
		res.a[1][1] = max(a[1][0] + t.a[0][1], a[1][1] + t.a[1][1]);
		return res;
	}
};
int fa[N], ch[N][2], dp[N][2];
Matrix val[N], sum[N];
inline bool nroot(int x) { return ch[fa[x]][0] == x || ch[fa[x]][1] == x; }
inline void pushup(int x) {
	sum[x] = val[x];
	if(ch[x][0]) sum[x] = sum[ch[x][0]] * sum[x];
	if(ch[x][1]) sum[x] = sum[x] * sum[ch[x][1]];
}
inline void rotate(int x) {
	int y = fa[x], z = fa[y], k = ch[y][1] == x, w = ch[x][!k];
	if(nroot(y)) ch[z][ch[z][1] == y] = x;
	ch[x][!k] = y, ch[y][k] = w;
	fa[w] = y, fa[y] = x, fa[x] = z;
	pushup(y);
}
inline void splay(int x) {
	while(nroot(x)) {
		int y = fa[x], z = fa[y];
		if(nroot(y)) rotate((ch[z][1] == y) ^ (ch[y][1] == x) ? x : y);
		rotate(x);
	}
	pushup(x);
}
inline void access(int x) {
	for(int y = 0; x; x = fa[y = x]) {
		splay(x);
		if(y) {
			val[x][0][0] -= max(sum[y][0][0], sum[y][1][0]);
			val[x][0][1] = val[x][0][0];
			val[x][1][0] -= sum[y][0][0];
		}
		if(ch[x][1]) {
			int t = ch[x][1];
			val[x][0][0] += max(sum[t][0][0], sum[t][1][0]);
			val[x][0][1] = val[x][0][0];
			val[x][1][0] += sum[t][0][0];
		}
		ch[x][1] = y, pushup(x);
	}
}
void dfs(int u, int ft) {
	dp[u][1] = a[u], fa[u] = ft;
	for(int v : e[u]) if(v != ft) {
		dfs(v, u);
		dp[u][0] += max(dp[v][0], dp[v][1]);
		dp[u][1] += dp[v][0];
	}
	val[u][0][0] = val[u][0][1] = dp[u][0];
	val[u][1][0] = dp[u][1], val[u][1][1] = -inf;
	sum[u] = val[u];
}
signed main() {
	n = read(), m = read();
	rep(i, 1, n) a[i] = read();
	rep(i, 2, n) {
		int x = read(), y = read();
		e[x].pb(y), e[y].pb(x);
	}
	dfs(1, 0);
	while(m--) {
		int x = read() ^ lastans, y = read();
		access(x), splay(x);
		val[x][1][0] += y - a[x], a[x] = y, pushup(x);
		splay(1);
		printf("%d\n", lastans = max(sum[1][0][0], sum[1][1][0]));
	}
	return 0;
}

考虑到这棵树并不会动,用 LCT 维护有点浪费,于是考虑搞一种新的方法来划分树。

有人从上古论文里翻出来了一个科技叫做“全局平衡二叉树”。

注意到 LCT 就是每一条链建平衡树,考虑用类似的思想。

建立的方法就是:先树剖,对于每一条重链每次取带权重心建立二叉树,连实边。对于轻子树建立的二叉树的根往当前二叉树上的节点拉虚边。

容易发现在同一颗二叉树内往父亲跳的时候,每跳一次子树大小都会至少翻倍;切换一次二叉树意味着切换一次重边,只会有 \(O(\log n)\) 次。仔细想想,这两个 \(\log\) 并不是乘起来的,是加起来的,因为子树大小翻倍至多 \(\log n\) 次,跳轻链也至多 \(\log n\) 次。所以树高是 \(O(\log n)\) 级别的,粗略分析上限是 \(2\log n\),注意有个常数。

仍然采用矩阵维护,维护方法类似 LCT,只维护实边信息,虚边一路跳到根更新 \(g\)

如果我们要查询某个子树的答案怎么办?

首先找到这个节点在全局平衡二叉树上所在的二叉树。

考虑到全局平衡二叉树上每一个由实边连接的二叉树都是一条重链,并且先序遍历就是这条重链,根据之前重链剖分时的思路,我们要求的是一个点到重链底端的矩阵积。

就相当于我们在二叉排序树上查询序列后缀积,这个随便写写就好了。

Code
#include <bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long LL;
#define fi first
#define se second
#define pb push_back
#define mkp make_pair
#define rep(i, x, y) for(int i = x, i##end = y; i <= i##end; ++i)
#define per(i, x, y) for(int i = x, i##end = y; i >= i##end; --i)
template<class T> inline bool ckmax(T &x, T y) { return x < y ? x = y, 1 : 0; }
template<class T> inline bool ckmin(T &x, T y) { return x > y ? x = y, 1 : 0; }
inline int read() {
    int x = 0, f = 1; char ch = getchar();
    while(!isdigit(ch)) { if(ch == '-') f = 0; ch = getchar(); }
    while(isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
    return f ? x : -x;
}

const int N = 1000005;
const int inf = 0x3f3f3f3f;
int n, m, a[N], f[N][2], g[N][2];
vector<int> e[N];
namespace Tree {

int siz[N], fa[N], son[N];
void dfs(int u, int ft) {
	siz[u] = 1, fa[u] = ft;
	f[u][1] = a[u];
	for(int v : e[u]) if(v != ft) {
		dfs(v, u), siz[u] += siz[v];
		if(siz[v] > siz[son[u]]) son[u] = v;
		f[u][0] += max(f[v][0], f[v][1]);
		f[u][1] += f[v][0];
	}
	g[u][0] = f[u][0] - max(f[son[u]][0], f[son[u]][1]);
	g[u][1] = f[u][1] - f[son[u]][0];
}

}

struct Matrix {
	int a[2][2];
	Matrix(){ memset(a, -0x3f, sizeof a); }
	inline int* operator [](const int &k) { return a[k]; }
	inline Matrix operator * (const Matrix &t) const {
		Matrix res;
		res.a[0][0] = max(a[0][0] + t.a[0][0], a[0][1] + t.a[1][0]);
		res.a[0][1] = max(a[0][0] + t.a[0][1], a[0][1] + t.a[1][1]);
		res.a[1][0] = max(a[1][0] + t.a[0][0], a[1][1] + t.a[1][0]);
		res.a[1][1] = max(a[1][0] + t.a[0][1], a[1][1] + t.a[1][1]);
		return res;
	}
	void print() {
		cerr << a[0][0] << ' ' << a[0][1] << '\n' << a[1][0] << ' ' << a[1][1] << '\n';
	}
};

namespace bst {
int fa[N], ch[N][2], stk[N], top, tsz[N], rt;
bool isrt[N];
Matrix val[N], sum[N];
inline void pushup(int u) {
	sum[u] = val[u];
	if(ch[u][0]) sum[u] = sum[ch[u][0]] * sum[u];
	if(ch[u][1]) sum[u] = sum[u] * sum[ch[u][1]];
}
inline int build2(int l, int r) {
	if(l > r) return 0;
	int ALL = 0, now = 0;
	rep(i, l, r) ALL += tsz[i];
	rep(i, l, r) {
		now += tsz[i];
		if(now << 1 >= ALL) {
			int u = stk[i];
			fa[ch[u][0] = build2(l, i - 1)] = u;
			fa[ch[u][1] = build2(i + 1, r)] = u;
			return pushup(u), u;
		}
	}
	assert(0);
}
int build(int tp) {
	for(int i = tp; i; i = Tree::son[i]) {
		for(int v : e[i]) if(v != Tree::fa[i] && v != Tree::son[i])
			fa[build(v)] = i;
		val[i][0][0] = val[i][0][1] = g[i][0];
		val[i][1][0] = g[i][1], val[i][1][1] = -inf;
	}
	top = 0;
	for(int i = tp; i; i = Tree::son[i])
		stk[++top] = i, tsz[top] = Tree::siz[i] - Tree::siz[Tree::son[i]];
	int tmp = build2(1, top);
	isrt[tmp] = 1;
	return tmp;
}
void modify(int x, int y) {
	val[x][1][0] += y - a[x], a[x] = y;
	for(int i = x; i; i = fa[i]) {
		Matrix pre = sum[i];
		pushup(i);
		Matrix suf = sum[i];
		if(isrt[i] && fa[i]) {
			int f = fa[i];
			val[f][0][0] += max(suf[0][0], suf[1][0]) - max(pre[0][0], pre[1][0]);
			val[f][0][1] = val[f][0][0];
			val[f][1][0] += suf[0][0] - pre[0][0];
		}
	}
}

}

signed main() {
	n = read(), m = read();
	rep(i, 1, n) a[i] = read();
	rep(i, 2, n) {
		int x = read(), y = read();
		e[x].pb(y), e[y].pb(x);
	}
	Tree::dfs(1, 0);
	bst::rt = bst::build(1);
	int lastans = 0;
	while(m--) {
		int x = read() ^ lastans, y = read();
		bst::modify(x, y);
		printf("%d\n", lastans = max(bst::sum[bst::rt][0][0], bst::sum[bst::rt][1][0]));
	}
	return 0;
}

到现在为止三种维护ddp常用的方法都已经介绍完毕,用哪种请读者自己选择。

从我个人角度不建议写树剖,因为复杂度多一只 \(\log\),可能被卡。而且其实树剖+线段树是三种写法里码量最大的。

给出我的实现下,这两道题在洛谷评测的最大测试点用时:

P4719(普通版) P4751(加强版)
树剖+线段树 179ms >3.7s(TLE)
LCT 77ms 2.95s
全局平衡二叉树 55ms 1.42s

毕竟每个人的实现都会有偏差,但是总体是可以看出每种方法的常数差别,在加强版体现的尤为突出。

小总结

上面解题的步骤其实是比较清晰的,也是一般做 ddp 题的步骤:

  • 写出不修改情况下的状态转移方程

  • 分离轻重儿子的贡献

  • 把转移写成矩阵

  • 大力码码码

一般都会在最开始暴力跑一趟树形dp求出不包括重儿子的答案塞进矩阵。

修改一般采用的方法是,消去原贡献,加入新贡献。

查询只要理解ddp本质都没问题。

例题二

P6021 洪水

小清新题,和模板没太大区别。

题意:给一棵树,每个点有点权 \(a_i\),每次询问:在某个子树内以点权为代价删除(堵上)一些点使得根与子树内所有叶子不连通的最小代价;带单点修改。

不带修情况

\(f_i\) 表示把以 \(i\) 为根的子树完全堵上的答案。

\[f_u=\min(a_u,\sum_{v \in son(u)} f_v) \]

转移简洁并且单点修改使我们想到使用 ddp 来维护。

分离轻重儿子(这里的 \(son(u)\) 表示 \(u\) 的重儿子):

\[f_u=\min(a_u,f_{wson(u)}+\sum_{v \not= wson(u)}f_v) \]

\[f_u=\min(a_u,f_{wson(u)}+g_u) \]

其中 \(g_u\) 是轻儿子的贡献

把转移方程写成矩阵

重定义广义矩阵乘法:

\[res_{i,j}=\min(a_{i,k}+b_{k,j}) \]

构造矩阵:

如果矩阵是一维的

\[\begin{bmatrix} x&y \end{bmatrix} * \begin{bmatrix} f_{wson(u)} \end{bmatrix} =\begin{bmatrix} f_u \end{bmatrix} \]

那么 \(x=g_u,y=a_u-f_{son(u)}\)

发现左矩阵做了一个重儿子的东西,不能做ddp。

考虑再加一维:

\[\begin{bmatrix} x&y\\ z&w \end{bmatrix} * \begin{bmatrix} f_{wson(u)}\\ p \end{bmatrix} = \begin{bmatrix} f_u\\ q \end{bmatrix} \]

因为 \(a_u=a_u+0\) ,考虑直接把 \(p,q\) 设成 \(0\) , 那么根据转移方程,\(x=g_u,y=a_u\) ,这样 \(f_u\) 已经被正确表示了。

但是底下的 \(q\) 在矩乘之后不一定是 \(0\) ,考虑通过 \(z,w\) 来维护 \(q\)

直接展开 \(q\)\(q=\min(z+f_{son(u)},w)\)\(p=0\) 就不写了)

\(f_{son(u)}\) 是非负的,所以 \(w=0,z\ge 0\) 即可

矩阵构造完毕!

\[\begin{bmatrix} g_u&a_u\\ 0&0 \end{bmatrix} * \begin{bmatrix} f_{wson(u)}\\ 0 \end{bmatrix} = \begin{bmatrix} f_u\\ 0 \end{bmatrix} \]

封死一颗子树的代价就是 \(f_u\)

这里要提一个小细节,就是叶子节点没有轻儿子的时候 \(g\) 怎么办。

我想到了两种处理方法:

一种处理方法是把 \(g_u\) 设为 \(a_u\),因为叶子节点在ddp的时候要满足 \(f_u=g_u\)

还有一种方法就是把 \(g_u\) 设成 \(+\infty\),直接禁止从“封死所有子树”这种方法的转移。并且矩阵左上角不和自己的 \(a\)\(\min\),只维护轻子树的 dp值 和,调用 dp 值的时候再和自己的 \(a\)\(\min\)

第一种在树剖的时候比较好搞;如果用 LCT 维护我暂时没想到什么好的维护方法所以用了第二种。

修改

\(x\)\(v\)

直接影响的就是这个节点的 \(a_u\)

但是叶子节点还得同时更改 \(g_u\),千万别忘。

至于轻边父亲的修改,减去原贡献加上新贡献就好了。

LCT 同理,不过在修改一个节点矩阵的时候要先 accesssplay,这时候修改它对于任何节点都是没有影响的。

查询

树剖直接用线段树把这个点到重链底端的矩阵全部乘起来就好了。

LCT 比较特殊。要把这个节点在 原树 上的父亲 access 一下,再 splay 这个节点,这样子这个节点的信息就是这颗子树的信息。

我access在lct上的父亲调了一晚上(((

矩乘只是 \(2\times2\times2\) ,建议手动暴力展开,可以快非常多。

因为树剖代码是在之前的代码上改的,怕有些地方与描述不同,因此重写了一份 LCT 的代码

树剖版代码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef double db;
#define pb(x) push_back(x)
#define mkp(x,y) make_pair(x,y)
//#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
//char buf[1<<21],*p1=buf,*p2=buf;
inline int read() {
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)) {if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch))x=x*10+(ch^48),ch=getchar();
	return x*f;
}
int rdc(){
	char ch=getchar();
	while(ch!='Q'&&ch!='C')ch=getchar();
	return ch=='Q';
}
const int N=200005;
const LL inf=1e14;
const int T=N<<2;
int n,dp[N];
LL a[N];
int head[N],num_edge;
int dfn[N],rev[N],tmr,fa[N],siz[N],son[N],top[N],ed[N];
struct edge{
	int nxt,to;
}e[N<<1];
void addedge(int fr,int to){
	++num_edge;
	e[num_edge].nxt=head[fr];
	e[num_edge].to=to;
	head[fr]=num_edge;
}
struct Matrix {
	LL a[2][2];
	Matrix(){a[0][0]=a[0][1]=a[1][0]=a[1][1]=inf;}
	LL*operator[](const int&k){return a[k];}
	Matrix operator * (const Matrix&b){
		Matrix res;
		res.a[0][0] = min(a[0][0]+b.a[0][0],a[0][1]+b.a[1][0]);
		res.a[0][1] = min(a[0][0]+b.a[0][1],a[0][1]+b.a[1][1]);
		res.a[1][0] = min(a[1][0]+b.a[0][0],a[1][1]+b.a[1][0]);
		res.a[1][1] = min(a[1][0]+b.a[0][1],a[1][1]+b.a[1][1]);
		return res;
	}
	void print(){
		printf("%lld %lld\n%lld %lld\n\n",a[0][0],a[0][1],a[1][0],a[1][1]);
	}
}mat[N],val[T];
void dfs1(int u,int ft){
	if(!e[head[u]].nxt)return dp[u]=a[u],siz[u]=1,void();
	LL sum=0;siz[u]=1;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;if(v==ft)continue;
		fa[v]=u,dfs1(v,u),sum+=dp[v],siz[u]+=siz[v];
		if(siz[v]>siz[son[u]])son[u]=v;
	}
	dp[u]=min(sum,a[u]);
}
void dfs2(int u,int tp){
	top[u]=tp,dfn[u]=++tmr,rev[tmr]=u,ed[tp]=tmr;
	mat[u][0][0]=0,mat[u][0][1]=a[u],
	mat[u][1][0]=0,mat[u][1][1]=0;
	if(!son[u])return mat[u][0][0]=a[u],void();
	dfs2(son[u],tp);
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa[u]||v==son[u])continue;
		dfs2(v,v),mat[u][0][0]+=dp[v];
	}
}
#define lc (p<<1)
#define rc (p<<1|1)
void pushup(int p){val[p]=val[lc]*val[rc];}
void build(int l,int r,int p=1){
	if(l==r)return val[p]=mat[rev[l]],void();
	int mid=(l+r)>>1;
	build(l,mid,lc),build(mid+1,r,rc);
	pushup(p);
}
Matrix query(int ql,int qr,int l=1,int r=n,int p=1){
	if(ql<=l&&r<=qr)return val[p];
	int mid=(l+r)>>1;
	if(qr<=mid)return query(ql,qr,l,mid,lc);
	if(mid<ql)return query(ql,qr,mid+1,r,rc);
	return query(ql,qr,l,mid,lc)*query(ql,qr,mid+1,r,rc);
}
void change(int pos,int l=1,int r=n,int p=1){
	if(l==r)return val[p]=mat[rev[pos]],void();
	int mid=(l+r)>>1;
	if(pos<=mid)change(pos,l,mid,lc);
	else change(pos,mid+1,r,rc);
	pushup(p);
}
void update(int x,int v){
	mat[x][0][1]+=v,a[x]+=v;
	if(siz[x]==1)mat[x][0][0]+=v;
	while(x){
		Matrix lst=query(dfn[top[x]],ed[top[x]]);
		change(dfn[x]);
		Matrix now=query(dfn[top[x]],ed[top[x]]);
		x=fa[top[x]];
		mat[x][0][0]+=now[0][0]-lst[0][0];
	}
}
signed main(){
	n=read();
	for(int i=1;i<=n;++i)a[i]=read();
	for(int i=1;i<n;++i){
		int x=read(),y=read();
		addedge(x,y),addedge(y,x);
	}
	dfs1(1,0),dfs2(1,1),build(1,n);
	
	for(int m=read();m;--m){
		int opt=rdc(),x=read();
		if(opt){
			Matrix t=query(dfn[x],ed[top[x]]);
			printf("%lld\n",t[0][0]);
		}
		else update(x,read());
	}
	return 0;
}
LCT 版代码
#include <bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long LL;
#define fi first
#define se second
#define pb push_back
#define mkp make_pair
#define sz(v) (int)(v).size()
#define rep(i, x, y) for(int i = x, i##end = y; i <= i##end; ++i)
#define per(i, x, y) for(int i = x, i##end = y; i >= i##end; --i)
template<class T> inline bool ckmax(T &x, T y) { return x < y ? x = y, 1 : 0; }
template<class T> inline bool ckmin(T &x, T y) { return x > y ? x = y, 1 : 0; }
inline int read() {
    int x = 0, f = 1; char ch = getchar();
    while(!isdigit(ch)) { if(ch == '-') f = 0; ch = getchar(); }
    while(isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
    return f ? x : -x;
}
inline int rdch() {
	char ch = getchar();
	while(ch != 'Q' && ch != 'C') ch = getchar();
	return ch == 'Q';
}
const int N = 200005;
const LL inf = 1e14;
int n, m, lef[N], tfa[N];
LL a[N], dp[N];
vector<int> e[N];
int fa[N], ch[N][2];
struct Matrix {
	LL a[2][2];
	Matrix() { a[0][0] = a[0][1] = a[1][0] = a[1][1] = inf; }
	inline LL* operator [](const int &k) { return a[k]; }
	inline Matrix operator * (const Matrix &t) const {
		Matrix res;
		res.a[0][0] = min(a[0][0] + t.a[0][0], a[0][1] + t.a[1][0]);
		res.a[0][1] = min(a[0][0] + t.a[0][1], a[0][1] + t.a[1][1]);
		res.a[1][0] = min(a[1][0] + t.a[0][0], a[1][1] + t.a[1][0]);
		res.a[1][1] = min(a[1][0] + t.a[0][1], a[1][1] + t.a[1][1]);
		return res;
	}
} val[N], sum[N];
inline bool nroot(int x) { return ch[fa[x]][0] == x || ch[fa[x]][1] == x; }
inline void pushup(int x) {
	sum[x] = val[x];
	if(ch[x][0]) sum[x] = sum[ch[x][0]] * sum[x];
	if(ch[x][1]) sum[x] = sum[x] * sum[ch[x][1]];
}
inline void rotate(int x) {
	int y = fa[x], z = fa[y], k = ch[y][1] == x, w = ch[x][!k];
	if(nroot(y)) ch[z][ch[z][1] == y] = x;
	ch[x][!k] = y, ch[y][k] = w;
	fa[w] = y, fa[y] = x, fa[x] = z;
	pushup(y);
}
inline void splay(int x) {
	while(nroot(x)) {
		int y = fa[x], z = fa[y];
		if(nroot(y)) rotate((ch[z][1] == y) ^ (ch[y][1] == x) ? x : y);
		rotate(x);
	}
	pushup(x);
}
inline void access(int x) {
	for(int y = 0; x; x = fa[y = x]) {
		splay(x);
		if(y) val[x][0][0] -= min(sum[y][0][0], sum[y][0][1]);
		if(ch[x][1]) val[x][0][0] += min(sum[ch[x][1]][0][0], sum[ch[x][1]][0][1]);
		ch[x][1] = y, pushup(x);
	}
}
void dfs(int u, int ft) {
	lef[u] = 1;
	for(int v : e[u]) if(v != ft)
		tfa[v] = fa[v] = u, dfs(v, u), dp[u] += dp[v], lef[u] = 0;
	if(lef[u]) dp[u] = a[u];
	val[u][0][0] = lef[u] ? inf : dp[u], val[u][0][1] = a[u];
	val[u][1][0] = val[u][1][1] = 0;
	pushup(u);
	ckmin(dp[u], a[u]);
}
signed main() {
	n = read();
	rep(i, 1, n) a[i] = read();
	rep(i, 2, n) {
		int x = read(), y = read();
		e[x].pb(y), e[y].pb(x);
	}
	dfs(1, 0);
	for(m = read(); m--; ) {
		int op = rdch(), x = read();
		if(op) {
			if(tfa[x]) access(tfa[x]);
			splay(x), printf("%lld\n", min(sum[x][0][0], sum[x][0][1]));
		} else {
			int y = read();
			access(x), splay(x);
			val[x][0][1] += y;
			pushup(x);
		}
	}
}

例题三

P5024 [NOIP2018 提高组] 保卫王国

分别钦定两个城市必取或者必不取的最小独立集。

对于一定驻扎,把点权设为 \(-\infty\)。对于一定不驻扎,点权设为 \(+\infty\)

然后跑最小独立集即可。

最后输出的时候加或减一下之前偏移的 \(\infty\)

可以偷懒把点权取反拉最大独立集的板子(

但是这里修改带四倍常数,用树剖写时间非常紧,uoj 上根本过不去,建议写全局平衡二叉树,我懒得重写了。

Code
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef double db;
#define pb(x) push_back(x)
#define mkp(x,y) make_pair(x,y)
//#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
//char buf[1<<21],*p1=buf,*p2=buf;
inline int read() {
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)) {if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch))x=x*10+(ch^48),ch=getchar();
	return x*f;
}
const int N=100005;
const int M=N<<2;
const LL inf=1e12;
int n,m;
LL p[N];
int siz[N],dfn[N],tmr,son[N],fa[N],top[N],rev[N],ed[N],f[N][2];
char cynAKIOI[114514];
struct edge{
	int nxt,to;
}e[N<<1];
int head[N],num_edge;
void addedge(int fr,int to){
	++num_edge;
	e[num_edge].nxt=head[fr];
	e[num_edge].to=to;
	head[fr]=num_edge;
}
struct Matrix{
	LL p[2][2];
	Matrix(){p[0][0]=p[0][1]=p[1][0]=p[1][1]=-inf;}
	LL*operator[](const int&k){return p[k];}
	Matrix operator * (const Matrix&b){
		Matrix res;
//		for(int i=0;i<2;++i)
//			for(int j=0;j<2;++j)
//				for(int k=0;k<2;++k)
//					res.p[i][j]=max(res.p[i][j],p[i][k]+b.p[k][j]);
		res[0][0]=max(p[0][0]+b.p[0][0],p[0][1]+b.p[1][0]);
		res[0][1]=max(p[0][0]+b.p[0][1],p[0][1]+b.p[1][1]);
		res[1][0]=max(p[1][0]+b.p[0][0],p[1][1]+b.p[1][0]);
		res[1][1]=max(p[1][0]+b.p[0][1],p[1][1]+b.p[1][1]);
		return res;
	}
}mat[N],val[M];
void dfs1(int u,int ft){
	siz[u]=1,f[u][1]=p[u];
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;if(v==ft)continue;
		fa[v]=u,dfs1(v,u),siz[u]+=siz[v];
		if(siz[v]>siz[son[u]])son[u]=v;
		f[u][0]+=max(f[v][0],f[v][1]);
		f[u][1]+=f[v][0];
	}
}
void dfs2(int u,int tp){
	top[u]=tp,dfn[u]=++tmr,rev[tmr]=u,ed[tp]=tmr;
	if(son[u])dfs2(son[u],tp);
	LL g[2];g[0]=0,g[1]=p[u];
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==son[u]||v==fa[u])continue;
		dfs2(v,v);
		g[0]+=max(f[v][0],f[v][1]);
		g[1]+=f[v][0];
	}
	mat[u][0][0]=g[0],mat[u][0][1]=g[0];
	mat[u][1][0]=g[1],mat[u][1][1]=-inf;
}
#define lc (p<<1)
#define rc (p<<1|1)
void pushup(int p){val[p]=val[lc]*val[rc];}
void build(int l,int r,int p){
	if(l==r)return val[p]=mat[rev[l]],void();
	int mid=(l+r)>>1;
	build(l,mid,lc),build(mid+1,r,rc);
	pushup(p);
}
Matrix query(int ql,int qr,int l=1,int r=n,int p=1){
	if(ql<=l&&r<=qr)return val[p];
	int mid=(l+r)>>1;
	if(qr<=mid)return query(ql,qr,l,mid,lc);
	if(mid<ql)return query(ql,qr,mid+1,r,rc);
	return query(ql,qr,l,mid,lc)*query(ql,qr,mid+1,r,rc);
}
void change(int pos,int l=1,int r=n,int p=1){
	if(l==r)return val[p]=mat[rev[l]],void();
	int mid=(l+r)>>1;
	if(pos<=mid)change(pos,l,mid,lc);
	else change(pos,mid+1,r,rc);
	pushup(p);
}
void update(int x,LL v){
	mat[x][1][0]+=v,p[x]+=v;
	while(x){
		Matrix lst=query(dfn[top[x]],ed[top[x]]);
		change(dfn[x]);
		Matrix now=query(dfn[top[x]],ed[top[x]]);
		x=fa[top[x]];
		mat[x][0][0]+=max(now[0][0],now[1][0])-max(lst[0][0],lst[1][0]);
		mat[x][0][1]=mat[x][0][0];
		mat[x][1][0]+=now[0][0]-lst[0][0];
	}
}
signed main(){
	n=read(),m=read(),scanf("%s",cynAKIOI);
	for(int i=1;i<=n;++i)p[0]+=(p[i]=read());
	for(int i=1;i<n;++i){
		int x=read(),y=read();
		addedge(x,y),addedge(y,x);
	}
	dfs1(1,0),dfs2(1,1),build(1,n,1);
	while(m--){
		int a=read(),x=read(),b=read(),y=read();
		LL ad1=x?-inf:inf,out1=x?0:inf;
		LL ad2=y?-inf:inf,out2=y?0:inf;
		update(a,ad1),update(b,ad2);
		Matrix res=query(dfn[1],ed[1]);
		LL out=p[0]-max(res[0][0],res[1][0])+out1+out2;
		out<inf?printf("%lld\n",out):puts("-1");
		update(a,-ad1),update(b,-ad2);
	}
	return 0;
}

例题四

P3781 [SDOI2017]切树游戏

先考虑不带修改的情况如何dp。

\(dp(u,msk)\) 表示以 \(u\) 为根的联通子树 \(\operatorname{xor}\) 起来为 \(msk\) 的方案数。

转移是:

\[f(u, msk) = \sum_{v\in son(u),X\oplus Y = msk}f(u, X)*f(v,Y) \]

暴力转移是 \(O(m^2)\),非常明显可以 FWT 优化成 \(O(m\log m)\)

统计答案的时候,假设询问 \(k\),那么就是 \(\sum_{i=1}^{n} dp(i,k)\)

我觉得这里还是有必要提一下暴力 dp 的边界以及转移的细节。

我一开始写的边界处理是:\(dp(u,w_u)=1\),在把所有孩子合并上来之后再给 \(dp(u,0)\) 加一,这样它父亲调用它的时候那个 \(0\) 就相当于不选自己。

凭直觉就知道在这种鬼地方多个类似 if 的东西可能非常难办。以及 FWT 和 IFWT 的位置可能影响我们维护修改的难度。还有我们统计答案的方式是遍历所有节点而非在单一节点统计答案。这些问题在一开始就得解决。

以下记 \(\hat{a}\) 表示 \(a\) FWT 之后的数组。

首先解决统计答案的问题。

考虑记 \(g(u,msk)=f(u,msk)+\sum_{v\in_{son(u)}}f(v,msk)\)

这样子我们调用 \(g(1,msk)\) 就能得到整颗树的答案了。

接下去看怎么把转移写简洁。

最后单独给 \(0\) 加一肯定要去掉。那么在转移方程后加一项就行了

\[f(u, msk) = \sum_{v\in son(u),X\oplus Y = msk}f(u, X)*f(v,Y)+f(u,msk) \]

FWT 之后有

\[\hat{f}(u,msk)=\hat{f}(u,msk)\hat{f}(v,msk)+\hat{f}(u,msk)=\hat{f}(u,msk)(\hat{f}(v,msk)+1) \]

于是这个转移可以写的非常简洁:

\[\hat{f}(u,msk)=\hat{w}(u,msk)\prod_{v\in son(u)} (\hat{f}(v,msk)+1) \]

\(\hat{w}(u)\) 表示这个点的点权 FWT 之后的序列。

考虑 \(g\) 怎么搞。如果 IFWT 回去再统计又会使转移非常麻烦。

注意到点值是可以直接加的,那不妨维护 \(\hat{g}\),最后 IFWT 回去输出答案。

\[\hat{g}(u,msk)=\hat{f}(u,msk)+\sum_{v\in son(u)}\hat{g}(v,msk) \]

现在转移方程非常简洁了,只不过复杂度是 \(O(qnm\log m)\),考虑怎么优化。

注意以下的 \(f,g\) 全部定义为多项式,乘法定义为按位相乘。

考虑 ddp。分离轻重儿子:

\[\hat{f}(u)=(\hat{f}(wson_u)+1)\hat{w}(u)\prod_{v\in son(u),v\not=wson(u)}(\hat{f}(v)+1)\\ \hat{g}(u)=\hat{f}(u)+\hat{g}(wson_u)+\sum_{v\in won(u),v\not=wson(u)}\hat{g}(v) \]

\[\hat{F}(u)=\hat{w}(u)\prod_{v\in son(u),v\not=wson(u)}(\hat{f}(v)+1)\\ \hat{G}(u)=\sum_{v\in won(u),v\not=wson(u)}\hat{g}(v) \]

那么 dp 就写成了下面的形式

\[\hat{f}(u)=(\hat{f}(wson_u)+1)\hat{F}(u)\\=\hat{F}(u)+\hat{F}(u)\hat{f}(wson_u)\\ \hat{g}(u)=\hat{f}(u)+\hat{g}(wson_u)+\hat{G}(u)\\=\hat{F}(u)+\hat{F}(u)\hat{f}(wson_u)+\hat{g}(wson_u)+\hat{G}(u) \]

然后就构造矩阵转移

\[\begin{bmatrix} \hat{F}(u) & 0 & \hat{F}(u)\\ \hat{F}(u)& 1 & \hat{F}(u)+\hat{G}(u)\\ 0 & 0 & 1 \end{bmatrix} * \begin{bmatrix} \hat{f}(wson_u)\\ \hat{g}(wson_u)\\ 1 \end{bmatrix} =\begin{bmatrix} \hat{f}(u)\\ \hat{g}(u)\\ 1 \end{bmatrix} \]

修改的时候只需要跳重链修改 \(\hat{F},\hat{G}\) 就好,就是消去原贡献加入现在的贡献。

但是 \(F\) 消贡献是除掉一个东西,并且 XOR 的 FWT 是可以出负数的,加上模数非常小,很有可能除一个 \(0\) 下去(样例就是),看起来非常棘手。

事实上这个处理非常简单:对于每个节点开桶存乘了几个 \(0\),除以 \(0\) 的时候操作桶就行了。

复杂度是 \(O(qm(\log n+\log m))\)

但是矩阵乘法带 \(27\) 常数,全局平衡二叉树带 \(2\) 倍常数,带进去一算是惊人的 1e9,加上大量封装,根本过不去。

这时候有个小 trick,有些矩阵矩乘之后常数不变,这个矩阵也是这样。

\[\begin{bmatrix} a_1 & 0 & c_1\\ b_1 & 1 & d_1\\ 0 & 0 & 1 \end{bmatrix} * \begin{bmatrix} a_2 & 0 & c_2\\ b_2 & 1 & d_2\\ 0 & 0 & 1 \end{bmatrix} = \begin{bmatrix} a_1a_2 & 0 & a_1c_2+c_1\\ b_1a_2+b_2 & 1 & b_1c_2+d_2+d_1\\ 0 & 0 & 1 \end{bmatrix} \]

于是只用维护四个值,常数就从 \(27\) 降到了 \(8\)

到此为止思路结束了,码代码就靠自己了(逃

不过这题别写树剖,多个 \(\log\) 运算量差不多是 1e9,加上洛谷有个毒瘤加了组对着树剖卡的数据,基本不用想过。

Code
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
	int x=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
	while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
	return f?x:-x;
}
inline int rdch() {
	char ch = getchar();
	while(ch != 'Q' && ch != 'C') ch = getchar();
	return ch == 'Q';
}
const int N = 30005;
const int mod = 10007;
const int iv2 = (mod + 1) >> 1;
int inv[N];
int n, m, w[N];
inline int qpow(int n, int k) {
	int res = 1;
	for(; k; k >>= 1, n = n * n % mod)
		if(k & 1) res = res * n % mod;
	return res;
}
inline int add(int x, int y) { return (x += y) >= mod ? x - mod : x; }
inline int sub(int x, int y) { return (x -= y) < 0 ? x + mod : x; }
struct pint {
	int v, c;
	pint() { v = 1, c = 1; }
	pint(int v_) {
		if(!v_) v = 1, c = 1;
		else v = v_, c = 0;
	}
	inline int val() const { return c ? 0 : v; }
	friend pint operator * (pint a, const int &b) {
		if(!b) return ++a.c, a;
		else return (a.v *= b) %= mod, a;
	}
	friend pint operator / (pint a, const int &b) {
		if(!b) return --a.c, a;
		else return (a.v *= inv[b]) %= mod, a;
	}
};
inline vector<int> change(const vector<pint> &a) {
	vector<int> res(m);
	for(int i = 0; i < m; ++i) res[i] = a[i].val();
	return res;
}
inline vector<pint> operator * (const vector<pint> &a, const vector<int> &b) {
	vector<pint> res(m);
	for(int i = 0; i < m; ++i) res[i] = a[i] * b[i];
	return res;
}
inline vector<pint> operator / (const vector<pint> &a, const vector<int> &b) {
	vector<pint> res(m);
	for(int i = 0; i < m; ++i) res[i] = a[i] / b[i];
	return res;
}
inline vector<int> operator + (const vector<int> &a, const vector<int> &b) {
	vector<int> res(m);
	for(int i = 0; i < m; ++i) res[i] = add(a[i], b[i]);
	return res;
}
inline vector<int> operator - (const vector<int> &a, const vector<int> &b) {
	vector<int> res(m);
	for(int i = 0; i < m; ++i) res[i] = sub(a[i], b[i]);
	return res;
}
inline vector<int> operator * (const vector<int> &a, const vector<int> &b) {
	vector<int> res(m);
	for(int i = 0; i < m; ++i) res[i] = a[i] * b[i] % mod;
	return res;
}
inline vector<int> addone(vector<int> a) {
	for(int i = 0; i < m; ++i) a[i] = add(a[i], 1);
	return a;
}
inline vector<int> XOR(vector<int> a) {
	for(int i = 1; i < m; i <<= 1)
		for(int j = 0; j < m; j += i << 1)
			for(int k = 0; k < i; ++k) {
				int X = a[j + k], Y = a[i + j + k];
				a[j + k] = add(X, Y), a[i + j + k] = sub(X, Y);
			}
	return a;
}
inline vector<int> IXOR(vector<int> a) {
	for(int i = 1; i < m; i <<= 1)
		for(int j = 0; j < m; j += i << 1)
			for(int k = 0; k < i; ++k) {
				int X = a[j + k], Y = a[i + j + k];
				a[j + k] = (X + Y) * iv2 % mod, a[i + j + k] = (X - Y + mod) * iv2 % mod;
			}
	return a;
}
int rt, tfa[N], fa[N], cnz[N], siz[N], son[N], stk[N], top, ch[N][2], tsz[N];
bool isrt[N];
vector<int> e[N];
struct Matrix {
	vector<int> a00, a10, a02, a12;
	inline Matrix operator * (const Matrix &t) const {
		Matrix res;
		res.a00 = a00 * t.a00;
		res.a10 = a10 * t.a00 + t.a10;
		res.a02 = a00 * t.a02 + a02;
		res.a12 = a10 * t.a02 + a12 + t.a12;
		return res;
	}
} val[N], sum[N];
Matrix mat;
vector<pint> F[N];
vector<int> G[N], ans, f[N], g[N], a[N];
inline void pushup(int u) {
	sum[u] = val[u];
	if(ch[u][0]) sum[u] = sum[ch[u][0]] * sum[u];
	if(ch[u][1]) sum[u] = sum[u] * sum[ch[u][1]];
}
inline void get(int u) {
	val[u].a00 = val[u].a10 = val[u].a02 = val[u].a12 = change(F[u]);
	val[u].a12 = val[u].a12 + G[u];
}
void dfs(int u, int ft) {
	f[u].resize(m), g[u].resize(m);
	f[u][w[u]] = 1, f[u] = XOR(f[u]);
	a[u] = f[u];
	
	siz[u] = 1;
	for(int v : e[u]) if(v != ft) {
		tfa[v] = u, dfs(v, u), siz[u] += siz[v];
		if(siz[v] > siz[son[u]]) son[u] = v;
		f[u] = f[u] * addone(f[v]), g[u] = g[u] + g[v];
	}
	g[u] = g[u] + f[u];
	
	F[u].resize(m), G[u].resize(m);
	for(int i = 0; i < m; ++i) F[u][i] = a[u][i];
	for(int v : e[u]) if(v != ft && v != son[u]) {
		F[u] = F[u] * addone(f[v]), G[u] = G[u] + g[v];
	}
	get(u);
}
inline int build2(int l, int r) {
	if(l > r) return 0;
	int ALL = 0, now = 0;
	for(int i = l; i <= r; ++i) ALL += tsz[i];
	for(int i = l; i <= r; ++i) {
		now += tsz[i];
		if(now << 1 >= ALL) {
			int u = stk[i];
			fa[ch[u][0] = build2(l, i - 1)] = u;
			fa[ch[u][1] = build2(i + 1, r)] = u;
			return pushup(u), u;
		}
	}
	return -1;
}
int build(int tp) {
	for(int i = tp; i; i = son[i])
		for(int v : e[i]) if(v != son[i] && v != tfa[i])
			fa[build(v)] = i;
	top = 0;
	for(int i = tp; i; i = son[i]) stk[++top] = i, tsz[top] = siz[i] - siz[son[i]];
	int tmp = build2(1, top);
	return isrt[tmp] = 1, tmp;
}
void modify(int x, int y) {
	F[x] = F[x] / a[x];
	memset(a[x].data(), 0, m << 2);
	a[x][y] = 1, w[x] = y, a[x] = XOR(a[x]);
	F[x] = F[x] * a[x], get(x);
	for(; x; x = fa[x]) {
		if(fa[x] && isrt[x]) {
			F[fa[x]] = F[fa[x]] / addone(sum[x].a02), G[fa[x]] = G[fa[x]] - sum[x].a12;
			pushup(x);
			F[fa[x]] = F[fa[x]] * addone(sum[x].a02), G[fa[x]] = G[fa[x]] + sum[x].a12;
			get(fa[x]);
		} else pushup(x);
	}
}
signed main() {
	inv[1] = 1;
	for(int i = 2; i < mod; ++i) inv[i] = inv[mod % i] * (mod - mod / i) % mod;
	n = read(), m = read();
	rep(i, 1, n) w[i] = read();
	rep(i, 2, n) {
		int x = read(), y = read();
		e[x].pb(y), e[y].pb(x);
	}
	dfs(1, 0);
	rt = build(1);
	ans = IXOR(sum[rt].a12);
	for(int q = read(); q; --q) {
		int op = rdch(), x = read();
		if(op == 1) {
			printf("%d\n", ans[x]);
		} else {
			int y = read();
			modify(x, y);
			ans = IXOR(sum[rt].a12);
		}
	}
	return 0;
}

参考资料

shadowice1984 P3781 的题解

Tweetuzki P4719 的题解

Great_Influence 对全局平衡二叉树的讲解

posted @ 2021-05-08 20:26  zzctommy  阅读(179)  评论(0编辑  收藏  举报