树哈希
树哈希
1.1 定义
1.1.1 同构树
我们定义,如果两颗有根树,交换其中节点的儿子后,两棵树形态一致,称这样的两棵树为同构树。
树哈希能做的就是判断两棵树是否同构。
1.1.2 哈希方法
树哈希十分灵活,也就是说你可以设计出你自己的哈希方式。但是显然,你设计的并不一定能满足正确性,可能被卡掉。
下面介绍一种不易被卡掉的树哈希方式。
我们设 \(hs_i\) 表示根为 \(i\) 的子树哈希值,那么我们现在需要设计一种多元函数 \(f\),然后按下面公式计算:
\[hs_i=f(\{hs_j\mid j\in \text{son}_i\})
\]
在 OI-wiki,多元函数为:
\[f(S)=(c+\sum_{x\in S}g(x))\bmod P
\]
其中 \(g(x)\) 又是一个一元函数。同常情况下,\(c\) 取 \(1\),\(P\) 取 \(2^{64}\)(也就是直接自然溢出)。
这种哈希有一个极强的优点:可以十分简单的实现换根 dp。
1.2 例子
解决树哈希相关问题的基本思路就是换根 dp。
1.2.1 换根 dp
以 [BJOI2015] 树的同构 为例,讲解树哈希如何换根。
首先我们计算出所有 \(hs_i\)。接下来我们令 \(rt_i\) 为根为 \(i\) 时的哈希值。由上面的公式可以得到:
\[rt_i=hs_i+g(rt_{fa}-g(hs_i))
\]
是不是十分的简单?
最后放一下这道题的代码:
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
const int Maxn = 2e5 + 5;
int head[Maxn], edgenum;
struct node {
int nxt, to;
}edge[Maxn];
void add(int from, int to) {
edge[++edgenum].nxt = head[from];
edge[edgenum].to = to;
head[from] = edgenum;
}
ull rad, hs[Maxn], rt[Maxn];
ull shift(ull x) {
x ^= rad;
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
x ^= rad;
return x;
}
int rot;
void dfs1(int p, int fa) {
hs[p] = 1;
for(int i = head[p]; i; i = edge[i].nxt) {
int to = edge[i].to;
dfs1(to, p);
hs[p] += shift(hs[to]);
}
}
void dfs2(int p, int fa) {
for(int i = head[p]; i; i = edge[i].nxt) {
int to = edge[i].to;
rt[to] = hs[to] + shift(rt[p] - shift(hs[to]));
dfs2(to, p);
}
}
map <ull, int> has;
int m, n;
int main() {
ios::sync_with_stdio(0);
srand(114514);
rad = rand();
cin >> m;
for(int t = 1; t <= m; t++) {
memset(head, 0, sizeof head);
memset(edge, 0, sizeof edge);
memset(hs, 0, sizeof hs);
memset(rt, 0, sizeof rt);
edgenum = 0;
cin >> n;
for(int i = 1; i <= n; i++) {
int u;
cin >> u;
if(u != 0) {
add(u, i);
}
else {
rot = i;
}
}
dfs1(rot, 0);
rt[rot] = hs[rot];
dfs2(rot, 0);
ull hash = 1;
for(int i = 1; i <= n; i++) {
hash += shift(rt[i]);
}
if(has[hash]) {
cout << has[hash] << '\n';
}
else {
has[hash] = t;
cout << t << '\n';
}
}
return 0;
}
上面代码中,shift 就是 \(g(x)\) 函数。
1.2.2 其他树形 dp
以 [hdu6647]Bracket Sequences on Tree 为例。
首先,我们发现,两颗同构的树所得出的方案数是一样的。
那么我们考虑换根求树哈希值的思路。首先设 \(f_i\) 表示第一次遍历以 \(i\) 为根的方案。那么根据上面的结论,会有方程:
\[f_i=\frac{deg_i!}{\prod dif_i!}\prod f_{to}
\]
其中 \(dif_i\) 表示 \(i\) 的每个同构子树的数量,利用树哈希求解即可。
这样 dp 完一遍后,我们考虑换根。换根方式与哈希是相同的,排除掉子树贡献然后加上去即可。
参考代码:
#include <bits/stdc++.h>
#define int long long
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int Maxn = 2e5 + 5;
const int Mod = 998244353;
int head[Maxn], edgenum;
struct node {
int nxt, to;
}edge[Maxn];
void add(int from, int to) {
edge[++edgenum].nxt = head[from];
edge[edgenum].to = to;
head[from] = edgenum;
}
int f[Maxn], g[Maxn];
int T;
int n;
int inv(int a) {
int b = Mod - 2, res = 1;
while(b) {
if(b & 1) {
res = (res * a) % Mod;
}
a = (a * a) % Mod;
b >>= 1;
}
return res;
}
ull rad, hs[Maxn], rt[Maxn];
ull F(ull x) {
x ^= rad;
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
x ^= rad;
return x;
}
int dp[Maxn], cha[Maxn], du[Maxn];
map <ull, int> ma[Maxn];
map <ull, int> ans;
void dfs1(int x, int fa) {
hs[x] = 1;
dp[x] = 1;
for(int i = head[x]; i; i = edge[i].nxt) {
int to = edge[i].to;
if(to == fa) continue;
dfs1(to, x);
hs[x] += F(hs[to]);
(dp[x] *= dp[to]) %= Mod;
du[x]++;
ma[x][hs[to]]++;
}
(dp[x] *= f[du[x]]) %= Mod;
for(auto it : ma[x]) {
(dp[x] *= g[it.second]) %= Mod;
}
}
void dfs2(int x, int fa) {
for(int i = head[x]; i; i = edge[i].nxt) {
int to = edge[i].to;
if(to == fa) continue;
ull p = rt[x];
int q = cha[x];
rt[x] -= F(hs[to]);
cha[x] = cha[x] * inv(du[x]) % Mod * ma[x][hs[to]] % Mod * inv(dp[to]) % Mod;
rt[to] = hs[to] + F(rt[x]);
cha[to] = dp[to] * (++du[to]) % Mod * cha[x] % Mod * inv(++ma[to][rt[x]]) % Mod;
rt[x] = p, cha[x] = q;
dfs2(to, x);
}
ans[rt[x]] = cha[x];
}
void init() {
for(int i = 1; i <= n; i++) {
head[i] = 0;
edge[i] = {0, 0};
dp[i] = 0;
cha[i] = 0;
hs[i] = 0;
rt[i] = 0;
du[i] = 0;
ma[i].clear();
}
ans.clear();
edgenum = 0;
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
srand(114514);
rad = rand();
f[0] = g[0] = 1;
for(int i = 1; i < Maxn; i++) {
f[i] = (f[i - 1] * i) % Mod;
g[i] = (g[i - 1] * inv(i)) % Mod;
}
cin >> T;
while(T--) {
init();
cin >> n;
for(int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
add(u, v), add(v, u);
}
dfs1(1, 0);
rt[1] = hs[1], cha[1] = dp[1];
dfs2(1, 0);
int cnt = 0;
for(auto i : ans) {
(cnt += i.second) %= Mod;
}
cout << cnt << '\n';
}
return 0;
}

浙公网安备 33010602011771号