# 点分治小结

## 算法介绍

• 1.经过根节点$root$的路径
• 2.不经过根节点$root$的路径

// Divide step of centroid decomposition (tutorial version).
// Marks u as removed, counts the qualifying pairs of u's component measured
// from u, subtracts the pairs whose endpoints lie in the same child subtree,
// then recurses into the centroid of each child component.
void dfs(int u) {
vis[u] = 1;
ans += solve(u, 0); // all cases: every pair in u's component, distances measured from u
for(int i = head[u]; i; i = e[i].nxt) {
if(vis[e[i].to]) continue;
int v = e[i].to;
ans -= solve(v, e[i].v); // subtract invalid cases: both endpoints in v's subtree, so the real path does not pass through u
// Centroid-finding code follows; why we look for the centroid is explained later.
now_sz = inf, root = 0; sz = siz[v];
find_root(v, 0);
dfs(root);
}
}


• 1.找一个根节点root
• 2.对root计算出d数组并计算答案
• 3.把root删了，对root的各个子树执行流程1,2

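更精确的写法：将 `sz = siz[v];` 改为 `sz = siz[v] > siz[u] ? totsiz - siz[u] : siz[v];`，其中 `totsiz` 是进入 `dfs(u)` 时保存的当前连通块大小。`siz[]` 是上一次 `find_root` 从某个点出发遍历时算出的；若 `v` 在那次遍历中是 `u` 的父亲，则 `siz[v] > siz[u]`，此时 `v` 所在连通块的真实大小应为 `totsiz - siz[u]`。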

## 例题：

### POJ1741 tree

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
#define N 100010

// Read one (optionally negative) decimal integer from stdin into x.
// Non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative.
// If EOF is reached before any digit, x is set to 0 and the function returns.
// The previous version stored getchar() in a `char` and never tested for EOF,
// so it looped forever at end of input.
inline void in(int &x) {
    x = 0;
    int f = 1;
    int c = getchar();  // int, not char, so EOF stays distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    x *= f;
}

// n = vertex count, k = distance bound, d[u] = distance from the current root,
// cnt = number of stored edges, head[u] = index of u's first edge, ans = result.
int n, k, d[N], cnt, head[N], ans;
// vis[u] = 1 once u has been used as a centroid; siz[u] = subtree size from
// the most recent find_root pass.
int vis[N], siz[N];

// Forward-star adjacency list entry: `to` = endpoint, `nxt` = next edge index
// out of the same vertex (0 ends the list), `v` = edge weight.
struct edge {
    int to, nxt, v;
} e[N << 1];

// Append the directed edge u -> v with weight w at the front of u's list.
// head[u] must be updated; without it no edge is ever reachable from head[].
void ins(int u, int v, int w) {
    e[++cnt] = (edge) {v, head[u], w};
    head[u] = cnt;
}

int now_sz = inf, root = 0, sz;

// Centroid search over the component containing u (vis vertices are treated
// as removed, fa is not re-entered). Fills siz[] with subtree sizes of this
// DFS and keeps in `root` the vertex whose largest remaining piece is
// smallest. The pieces are its child subtrees plus the part above it,
// sz - siz[u]. That piece size is stored in now_sz.
// The caller (dfs) sets sz to the component size and resets now_sz first.
void find_root(int u, int fa) {
    siz[u] = 1;
    int heaviest = 0;
    for (int i = head[u]; i; i = e[i].nxt) {
        const int to = e[i].to;
        if (to == fa || vis[to]) continue;
        find_root(to, u);
        siz[u] += siz[to];
        if (siz[to] > heaviest) heaviest = siz[to];
    }
    const int above = sz - siz[u];
    if (above > heaviest) heaviest = above;
    if (heaviest < now_sz) {
        now_sz = heaviest;
        root = u;
    }
}

int a[N], tot;

// Append to a[1..tot] the distance d[] of every vertex reachable from u
// without entering a removed (vis) vertex or returning to fa.
// d[u] must already be set; each child gets d[child] = d[u] + edge weight.
void get_dis(int u, int fa) {
    a[++tot] = d[u];
    for (int i = head[u]; i; i = e[i].nxt) {
        const int to = e[i].to;
        if (to == fa || vis[to]) continue;
        d[to] = d[u] + e[i].v;
        get_dis(to, u);
    }
}

// Count unordered pairs of distinct vertices in u's component whose
// distances satisfy d[x] + d[y] <= k, where u itself gets distance `dis`.
// dfs passes dis = edge weight for a child, to subtract the pairs lying
// entirely inside that child's subtree.
int solve(int u, int dis) {
    tot = 0;
    d[u] = dis;
    get_dis(u, u);
    sort(a + 1, a + tot + 1);
    int pairs = 0;
    int hi = tot;
    // Two pointers over the sorted distances: for each lo, pull hi down until
    // a[lo] + a[hi] <= k; then every index in (lo, hi] pairs with lo.
    for (int lo = 1; lo < hi; ++lo) {
        while (lo < hi && a[lo] + a[hi] > k) --hi;
        if (lo < hi) pairs += hi - lo;
    }
    return pairs;
}

// Centroid decomposition driver. Marks u as removed and adds every qualifying
// pair of u's component measured from u. For each remaining neighbour it
// subtracts the pairs whose two endpoints sit in that neighbour's subtree,
// since their real path does not pass through u. It then recurses into the
// centroid of that subtree.
void dfs(int u) {
    vis[u] = 1;
    ans += solve(u, 0);
    for (int i = head[u]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (vis[v]) continue;
        ans -= solve(v, e[i].v);
        // Locate the centroid of v's component before recursing.
        sz = siz[v];
        root = 0;
        now_sz = inf;
        find_root(v, 0);
        dfs(root);
    }
}

int main() {
while(~scanf("%d%d", &n, &k) && n && k) {
ans = 0; cnt = 0;
memset(vis, 0, sizeof(vis));
for(int i = 1; i < n; ++i) {
int u, v, w; in(u), in(v), in(w);
ins(u, v, w), ins(v, u, w);
}
dfs(1);
printf("%d\n", ans);
}
}


### BZOJ2152: 聪聪可可

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
#define N 100010

// Read one (optionally negative) decimal integer from stdin into x.
// Non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative.
// If EOF is reached before any digit, x is set to 0 and the function returns.
// The previous version stored getchar() in a `char` and never tested for EOF,
// so it looped forever at end of input.
inline void in(int &x) {
    x = 0;
    int f = 1;
    int c = getchar();  // int, not char, so EOF stays distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    x *= f;
}

// n = vertex count, d[u] = distance from the current root, cnt = number of
// stored edges, head[u] = index of u's first edge, ans = counted pairs.
int n, k, d[N], cnt, head[N], ans;
// vis[u] = 1 once u is a used centroid; siz[u] = subtree size from the last
// find_root pass; sum[r] = how many collected distances are == r (mod 3).
int vis[N], siz[N], sum[3];

// Forward-star adjacency list entry: `to` = endpoint, `nxt` = next edge index
// out of the same vertex (0 ends the list), `v` = edge weight.
struct edge {
    int to, nxt, v;
} e[N << 1];

// Append the directed edge u -> v with weight w at the front of u's list.
// head[u] must be updated; without it no edge is ever reachable from head[].
void ins(int u, int v, int w) {
    e[++cnt] = (edge) {v, head[u], w};
    head[u] = cnt;
}

int now_siz, sz, root;

// Centroid search over the component containing u (vis vertices are treated
// as removed, fa is not re-entered). Fills siz[] with subtree sizes of this
// DFS and records in `root` the vertex whose largest remaining piece is
// smallest. The pieces are its child subtrees plus sz - siz[u] above it.
void find_root(int u, int fa) {
    siz[u] = 1;
    int heaviest = 0;
    for (int i = head[u]; i; i = e[i].nxt) {
        const int to = e[i].to;
        if (vis[to] || to == fa) continue;
        find_root(to, u);
        siz[u] += siz[to];
        if (siz[to] > heaviest) heaviest = siz[to];
    }
    const int above = sz - siz[u];
    if (above > heaviest) heaviest = above;
    if (heaviest < now_siz) {
        now_siz = heaviest;
        root = u;
    }
}

// Tally, by residue mod 3, the distance d[] of every vertex in u's component.
// Removed (vis) vertices and the parent fa are not entered.
// d[u] must be set by the caller; each child gets d[child] = d[u] + weight.
void get_dis(int u, int fa) {
    ++sum[d[u] % 3];
    for (int i = head[u]; i; i = e[i].nxt) {
        const int to = e[i].to;
        if (to == fa || vis[to]) continue;
        d[to] = d[u] + e[i].v;
        get_dis(to, u);
    }
}

// Count ordered pairs (x, y), x == y allowed, in u's component (with
// d[u] = dis) such that (d[x] + d[y]) % 3 == 0. That happens when both
// residues are 0, or one is 1 and the other 2 (two orders).
int solve(int u, int dis) {
    for (int r = 0; r < 3; ++r) sum[r] = 0;
    d[u] = dis;
    get_dis(u, u);
    const int bothZero = sum[0] * sum[0];
    const int oneAndTwo = 2 * sum[1] * sum[2];
    return bothZero + oneAndTwo;
}

// Centroid decomposition: add the qualifying pairs of u's whole component,
// remove u, then for each remaining neighbour subtract the pairs that stay
// inside that neighbour's subtree and recurse into its centroid.
void dfs(int u) {
    ans += solve(u, 0);
    vis[u] = 1;
    for (int i = head[u]; i; i = e[i].nxt) {
        const int to = e[i].to;
        if (vis[to]) continue;
        ans -= solve(to, e[i].v);
        root = 0;
        sz = siz[to];
        now_siz = inf;
        find_root(to, u);
        dfs(root);
    }
}

int main() {
in(n);
for(int i = 1; i < n; ++i) {
int u, v, w; in(u), in(v), in(w);
ins(u, v, w), ins(v, u, w);
}
now_siz = inf; root = 0; sz = n;
find_root(1, 1);
dfs(root);
int now = n * n, g = __gcd(now, ans);
printf("%d/%d\n", ans / g, now / g);
}


### LuoguP3806 【模板】点分治1

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define inf 0x3f3f3f3f
#define ll long long
#define N 100010
#define lim 10000000

// Read one (optionally negative) decimal integer from stdin into x.
// Non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative.
// If EOF is reached before any digit, x is set to 0 and the function returns.
// The previous version stored getchar() in a `char` and never tested for EOF,
// so it looped forever at end of input.
inline void in(int &x) {
    x = 0;
    int f = 1;
    int c = getchar();  // int, not char, so EOF stays distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    x *= f;
}

// top = number of entries in st[], n = vertex count, m = query count,
// d[u] = distance from the current root, cnt = number of stored edges,
// head[u] = index of u's first edge, ans[i] = pair count found for query i.
int top, n, m, d[N], cnt, head[N], ans[110];
// vis[u] = 1 once u is a used centroid; siz[u] = subtree size from the last
// find_root pass; q[i] = queried distance; st[] = collected distances;
// s[x] = how many collected distances equal x (only x <= lim is recorded).
int vis[N], siz[N], q[110], st[N], s[10000010];

// Forward-star adjacency list entry: `to` = endpoint, `nxt` = next edge index
// out of the same vertex (0 ends the list), `v` = edge weight.
struct edge {
    int to, nxt, v;
} e[N << 1];

// Append the directed edge u -> v with weight w at the front of u's list.
// head[u] must be updated; without it no edge is ever reachable from head[].
void ins(int u, int v, int w) {
    e[++cnt] = (edge) {v, head[u], w};
    head[u] = cnt;
}

int now_sz = inf, root, sz;

// Centroid search over the component containing u (vis vertices are treated
// as removed, fa is not re-entered). Fills siz[] with subtree sizes of this
// DFS and records in `root` the vertex whose largest remaining piece is
// smallest. The pieces are its child subtrees plus sz - siz[u] above it.
void find_root(int u, int fa) {
    siz[u] = 1;
    int heaviest = 0;
    for (int i = head[u]; i; i = e[i].nxt) {
        const int to = e[i].to;
        if (vis[to] || to == fa) continue;
        find_root(to, u);
        if (siz[to] > heaviest) heaviest = siz[to];
        siz[u] += siz[to];
    }
    const int above = sz - siz[u];
    if (above > heaviest) heaviest = above;
    if (heaviest < now_sz) {
        now_sz = heaviest;
        root = u;
    }
}

// Push onto st[1..top] the distance d[] of every vertex in u's component.
// Removed (vis) vertices and the parent fa are not entered.
// d[u] must be set by the caller; each child gets d[child] = d[u] + weight.
void get_dis(int u, int fa) {
    ++top;
    st[top] = d[u];
    for (int i = head[u]; i; i = e[i].nxt) {
        const int to = e[i].to;
        if (vis[to] || to == fa) continue;
        d[to] = d[u] + e[i].v;
        get_dis(to, u);
    }
}

// For every query q[i], add op times the number of ordered vertex pairs
// (x, y) in u's component (with d[u] = dis) satisfying d[x] + d[y] == q[i].
// dfs passes op = +1 for a whole component and op = -1 for a child subtree.
// The bucket s[] only records distances <= lim, and is returned to its
// previous contents before this function exits.
void solve(int u, int dis, int op) {
    top = 0;
    d[u] = dis;
    get_dis(u, 0);
    for (int j = 1; j <= top; ++j)
        if (st[j] <= lim) ++s[st[j]];
    for (int i = 1; i <= m; ++i) {
        const int want = q[i];
        for (int j = 1; j <= top; ++j) {
            if (st[j] > want) continue;
            ans[i] += s[want - st[j]] * op;
        }
    }
    for (int j = 1; j <= top; ++j)
        if (st[j] <= lim) --s[st[j]];
}

// Centroid decomposition: remove u, add the matches of u's whole component,
// then for each remaining neighbour subtract the matches lying inside that
// neighbour's subtree and recurse into its centroid.
void dfs(int u) {
    vis[u] = 1;
    solve(u, 0, 1);
    for (int i = head[u]; i; i = e[i].nxt) {
        const int to = e[i].to;
        if (vis[to]) continue;
        // solve() resets top and sets d[to] itself, so no setup is needed.
        solve(to, e[i].v, -1);
        now_sz = inf;
        root = 0;
        sz = siz[to];
        find_root(to, u);
        dfs(root);
    }
}

int main() {
    in(n), in(m);
    for (int i = 1; i < n; ++i) {
        int a, b, w;
        in(a), in(b), in(w);
        ins(a, b, w), ins(b, a, w);
    }
    for (int i = 1; i <= m; ++i) in(q[i]);
    // Start the decomposition from the centroid of the whole tree.
    root = 0;
    now_sz = inf;
    sz = n;
    find_root(1, 1);
    dfs(root);
    for (int i = 1; i <= m; ++i) puts(ans[i] ? "AYE" : "NAY");
}


### CF161D Distance in Tree

#include <bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define il inline

namespace io {

// in(a) reads the next integer from stdin into a; out/outn write one.
#define in(a) a = read()
#define out(a) write(a)
#define outn(a) out(a), putchar('\n')

#define I_int long long
// Read one (optionally negative) integer from stdin. Non-digit characters
// before the number are skipped; a '-' among them makes the result negative.
// Returns 0 if EOF is reached before any digit, instead of looping forever.
// (The function header was missing, leaving this body at namespace scope.)
inline I_int read() {
    I_int x = 0, f = 1;
    int c = getchar();  // int, not char, so EOF stays distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + c - '0';
        c = getchar();
    }
    return x * f;
}
char F[200];
// Write x in decimal to stdout, without a trailing newline.
inline void write(I_int x) {
    if (x == 0) return (void) (putchar('0'));
    I_int tmp = x > 0 ? x : -x;
    if (x < 0) putchar('-');
    int cnt = 0;
    while (tmp > 0) {
        F[cnt++] = tmp % 10 + '0';
        tmp /= 10;
    }
    while (cnt > 0) putchar(F[--cnt]);
}
#undef I_int

}
using namespace io;

using namespace std;

#define N 100010

int n, k;
struct edge {
int to, nxt;
}e[N<<1];

void ins(int u, int v) {
}

int siz[N], now_sz = inf, root, sz;

// Centroid search over the component containing u (vis vertices are treated
// as removed, fa is not re-entered). Fills siz[] with subtree sizes of this
// DFS and records in `root` the vertex whose largest remaining piece is
// smallest. The pieces are its child subtrees plus sz - siz[u] above it.
void find_root(int u, int fa) {
    siz[u] = 1;
    int heaviest = 0;
    for (int i = head[u]; i; i = e[i].nxt) {
        const int to = e[i].to;
        if (vis[to] || to == fa) continue;
        find_root(to, u);
        siz[u] += siz[to];
        if (siz[to] > heaviest) heaviest = siz[to];
    }
    const int above = sz - siz[u];
    if (above > heaviest) heaviest = above;
    if (heaviest < now_sz) {
        now_sz = heaviest;
        root = u;
    }
}

int top, st[N], s[N];

// Record the depth d[] of every vertex in u's component in st[1..top].
// Depths not exceeding k are also counted in the bucket s[].
// Removed (vis) vertices and the parent fa are not entered.
void get_dis(int u, int fa) {
    const int depth = d[u];
    st[++top] = depth;
    if (depth <= k) ++s[depth];
    for (int i = head[u]; i; i = e[i].nxt) {
        const int to = e[i].to;
        if (vis[to] || to == fa) continue;
        d[to] = depth + 1;
        get_dis(to, u);
    }
}

// Count unordered pairs of distinct vertices in u's component (with
// d[u] = dis) whose depths satisfy d[x] + d[y] == k.
// Both buckets are cleared once used, so each depth pair is counted at most
// once, and s[] is left all zero.
ll solve(int u, int dis) {
    top = 0;
    d[u] = dis;
    get_dis(u, 0);
    ll res = 0;
    for (int i = 1; i <= top; ++i) {
        const int a = st[i];
        if (a > k) continue;
        const int b = k - a;
        if (a == b) res += 1ll * s[a] * (s[a] - 1) / 2;
        else res += 1ll * s[b] * s[a];
        s[a] = 0;
        s[b] = 0;
    }
    return res;
}

ll ans = 0;

// Centroid decomposition: count the pairs of u's component measured from u,
// subtract those whose endpoints share a child subtree, then recurse into the
// centroid of each child component. On entry sz is expected to hold the size
// of u's component (set by the caller before find_root).
void dfs(int u) {
    vis[u] = 1;
    ans += solve(u, 0);
    const int compSize = sz;  // sz is overwritten by the recursion below
    for (int i = head[u]; i; i = e[i].nxt) {
        const int to = e[i].to;
        if (vis[to]) continue;
        ans -= solve(to, 1);
        // siz[] comes from the find_root pass that picked u. If `to` was u's
        // parent in that pass, its piece is everything outside u's subtree.
        if (siz[to] > siz[u]) sz = compSize - siz[u];
        else sz = siz[to];
        now_sz = inf;
        root = 0;
        find_root(to, 0);
        dfs(root);
    }
}

int main() {
in(n), in(k);
for(int i = 1; i < n; ++i) {