树hash

判断两棵树是否同构，可以采用树哈希（树hash）的方法：给每棵树算出一个hash值，hash值相同就认为两棵树同构。

树hash定义在有根树上。判断无根树是否同构时，可以以重心为根计算hash（一棵树至多有两个重心，两个都要算），或者求出每个点为根时的hash值再整体比较。

h[x] 表示以 x 为根的子树的hash，g[x] 表示以 x 为根时整棵树的hash；siz[x] 表示 x 的子树大小，n 为树的点数。

我采用的方法是

h[x] = 1 + ∑ h[y] * p[siz[y]]，其中 y 取遍 x 的所有儿子，p[k] 表示第 k 个质数（p[1] = 2, p[2] = 3, …），所有运算都在模 998244353 意义下进行。

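于是（fa 为 x 的父亲，且 g[根] = h[根]）g[x] = h[x] + (g[fa] - h[x] * p[siz[x]]) * p[n - siz[x]]。其中 g[fa] - h[x] * p[siz[x]] 是以 fa 为根、去掉 x 的子树后剩余部分的hash；换根到 x 后，这部分成为 x 的一棵大小为 n - siz[x] 的子树，所以要乘上 p[n - siz[x]]。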

例题1: BJOI2015 树的同构

判断无根树同构时，我求出了每个点为根时的hash值，把它们排序后逐个比较：两棵树排序后的序列完全相同，就认为它们同构。

  1 #include <bits/stdc++.h>
  2 
  3 typedef long long LL;
  4 const int N = 60, MO = 998244353;
  5 
  6 struct Edge {
  7     int nex, v;
  8 }edge[N << 1]; int tp;
  9 
 10 int e[N], n, m, turn, fr[N], p[1000010], top, siz[N], h[N], g[N];
 11 std::vector<int> v[N];
 12 bool vis[1000010];
 13 
 14 inline void add(int x, int y) {
 15     tp++;
 16     edge[tp].v = y;
 17     edge[tp].nex = e[x];
 18     e[x] = tp;
 19     return;
 20 }
 21 
 22 inline bool equal(int a, int b) {
 23     int len = v[a].size();
 24     if(len != v[b].size()) return false;
 25     for(int i = 0; i < len; i++) {
 26         if(v[a][i] != v[b][i]) return false;
 27     }
 28     return true;
 29 }
 30 
 31 inline void getp(int n) {
 32     for(int i = 2; i <= n; i++) {
 33         if(!vis[i]) {
 34             p[++top] = i;
 35         }
 36         for(int j = 1; j <= top && i * p[j] <= n; j++) {
 37             vis[i * p[j]] = 1;
 38             if(i % p[j] == 0) {
 39                 break;
 40             }
 41         }
 42     }
 43     return;
 44 }
 45 
 46 void DFS_1(int x, int f) {
 47     siz[x] = 1;
 48     h[x] = 1;
 49     for(int i = e[x]; i; i = edge[i].nex) {
 50         int y = edge[i].v;
 51         if(y == f) continue;
 52         DFS_1(y, x);
 53         h[x] = (h[x] + 1ll * h[y] * p[siz[y]] % MO) % MO;
 54         siz[x] += siz[y];
 55     }
 56     return;
 57 }
 58 
 59 void DFS_2(int x, int f, int V) {
 60     g[x] = (h[x] + 1ll * V * p[n - siz[x]] % MO) % MO;
 61     v[turn].push_back(g[x]);
 62     V = (1ll * V * p[n - siz[x]] % MO + 1) % MO;
 63     for(int i = e[x]; i; i = edge[i].nex) {
 64         int y = edge[i].v;
 65         if(y == f) {
 66             continue;
 67         }
 68         DFS_2(y, x, ((LL)V + h[x] - 1 - 1ll * h[y] * p[siz[y]] % MO + MO) % MO);
 69     }
 70     return;
 71 }
 72 
 73 int main() {
 74     getp(1000009);
 75     scanf("%d", &m);
 76     for(turn = 1; turn <= m; turn++) {
 77         scanf("%d", &n);
 78         tp = 0;
 79         memset(e + 1, 0, n * sizeof(int));
 80         for(int i = 1, x; i <= n; i++) {
 81             scanf("%d", &x);
 82             if(x) {
 83                 add(x, i);
 84                 add(i, x);
 85             }
 86         }
 87         DFS_1(1, 0);
 88         DFS_2(1, 0, 0);
 89         std::sort(v[turn].begin(), v[turn].end());
 90         /*for(int i = 0; i < n; i++) {
 91             printf("%d ", v[turn][i]);
 92         }
 93         puts("\n");*/
 94     }
 95 
 96     for(int i = 1; i <= m; i++) {
 97         fr[i] = i;
 98     }
 99     for(int i = 2; i <= m; i++) {
100         for(int j = 1; j < i; j++) {
101             if(equal(i, j)) {
102                 fr[i] = fr[j];
103                 break;
104             }
105         }
106     }
107     for(int i = 1; i <= m; i++) {
108         printf("%d\n", fr[i]);
109     }
110     return 0;
111 }
AC代码

例题2: JSOI2016 独特的树叶

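对第一棵树求出所有点为根时的hash值，放进set。然后按编号从小到大枚举第二棵树的叶子：以该叶子唯一的邻点为根时，这个叶子是一棵大小为 1、hash 为 1 的子树，贡献为 1 * p[1] = 2，所以 g[邻点] - 2 就是删去该叶子后（仍以邻点为根）的hash。第一个能在set中查到的叶子即为答案。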

  1 #include <bits/stdc++.h>
  2 
  3 const int N = 100010, MO = 998244353;
  4 
  5 struct Edge {
  6     int nex, v;
  7 };
  8 
  9 std::set<int> st;
 10 int p[2000010], top, in[N], near[N];
 11 bool vis[2000010];
 12 
 13 inline int qpow(int a, int b) {
 14     int ans = 1;
 15     while(b) {
 16         if(b & 1) ans = 1ll * ans * a % MO;
 17         a = 1ll * a * a % MO;
 18         b = b >> 1;
 19     }
 20     return ans;
 21 }
 22 
 23 struct Tree {
 24     Edge edge[N << 1]; int tp;
 25     int e[N], h[N], g[N], siz[N], n;
 26     inline void init(int t) {
 27         n = t;
 28         tp = 0;
 29         memset(e + 1, 0, n * sizeof(int));
 30         return;
 31     }
 32     inline void add(int x, int y) {
 33         edge[++tp].v = y;
 34         edge[tp].nex = e[x];
 35         e[x] = tp;
 36         return;
 37     }
 38     void DFS_1(int x, int f) {
 39         siz[x] = 1;
 40         h[x] = 1;
 41         for(int i = e[x]; i; i = edge[i].nex) {
 42             int y = edge[i].v;
 43             if(y == f) {
 44                 continue;
 45             }
 46             DFS_1(y, x);
 47             siz[x] += siz[y];
 48             h[x] = (h[x] + 1ll * h[y] * p[siz[y]] % MO) % MO;
 49         }
 50         //printf("x = %d h[x] = %d \n", x, h[x]);
 51         return;
 52     }
 53     void DFS_2(int x, int f, int V) {
 54         g[x] = (h[x] + 1ll * V * p[n - siz[x]] % MO) % MO;
 55         //printf("x = %d v = %d g = %d \n", x, V, g[x]);
 56         for(int i = e[x]; i; i = edge[i].nex) {
 57             int y = edge[i].v;
 58             if(y == f) {
 59                 continue;
 60             }
 61             DFS_2(y, x, (g[x] - 1ll * h[y] * p[siz[y]] % MO + MO) % MO);
 62         }
 63         return;
 64     }
 65 }t0, t1;
 66 
 67 inline void getp(int n) {
 68     for(int i = 2; i <= n; i++) {
 69         if(!vis[i]) {
 70             p[++top] = i;
 71         }
 72         for(int j = 1; j <= top && i * p[j] <= n; j++) {
 73             vis[i * p[j]] = 1;
 74             if(i % p[j] == 0) {
 75                 break;
 76             }
 77         }
 78     }
 79     return;
 80 }
 81 
 82 int main() {
 83 
 84     getp(2000009);
 85 
 86     int n;
 87     scanf("%d", &n);
 88     t0.init(n);
 89     t1.init(n + 1);
 90     int x, y;
 91     for(int i = 1; i < n; i++) {
 92         scanf("%d%d", &x, &y);
 93         t0.add(x, y);
 94         t0.add(y, x);
 95     }
 96     for(int i = 1; i <= n; i++) {
 97         scanf("%d%d", &x, &y);
 98         t1.add(x, y);
 99         t1.add(y, x);
100         in[x]++;
101         in[y]++;
102         near[x] = y;
103         near[y] = x;
104     }
105 
106     t0.DFS_1(1, 0);
107     t0.DFS_2(1, 0, 0);
108     for(int i = 1; i <= n; i++) {
109         st.insert(t0.g[i]);
110         //printf("%d hash = %d \n", i, t0.g[i]);
111     }
112 
113     t1.DFS_1(1, 0);
114     t1.DFS_2(1, 0, 0);
115     for(int i = 1; i <= n + 1; i++) {
116         if(in[i] == 1) {
117             int x = (t1.g[near[i]] - 2 + MO) % MO;
118             if(st.find(x) != st.end()) {
119                 printf("%d\n", i);
120                 return 0;
121             }
122         }
123     }
124 
125     return 0;
126 }
AC代码

更多例题:

posted @ 2019-05-06 08:48  huyufeifei  阅读(...)  评论(...)  编辑  收藏