CF771C
提供一个不需要换根的树形 \(\text{dp}\) 做法。
假如只有一次询问,那么答案为树上两点间距离除以 \(k\) 向上取整,那么很自然地想到能否直接求树上所有路径长度和,然后除以 \(k\) 向上取整?显然是不行的,因为每条路径长除以 \(k\) 的余数合并后可能错误地减少贡献。于是我们考虑将路径长除以 \(k\) 的整数部分和余数部分分开计算。
观察数据范围,发现 \(1\le k \le 5\),可以将路径长按除以 \(k\) 的余数分为 \(5\) 类。设 \(dp_x\) 表示以 \(x\) 为根的子树的答案。设 \(f_x\) 表示以 \(x\) 为根的子树中以 \(x\) 为其中一端点的所有链长除以 \(k\) 向下取整的和。设 \(g_{x,i}\) 表示以 \(x\) 为根的子树中以 \(x\) 为其中一端点且长度模 \(k\) 余 \(i\) 的链的数量。
考虑如何转移:设当前节点为 \(x\),它其中一个子节点为 \(y\)。
首先考虑如何转移 \(f_x\) 与 \(g_{x,i}\)。当从 \(y\) 点走到 \(x\) 点时,所有以 \(y\) 为较浅端点往下的链长度会增加 \(1\)。于是可以得到如下转移:
然后考虑如何转移得到 \(dp_x\),分为这样几部分:每个子树内部的答案 \(dp_y\),两个子树间余数部分的合并,每一个子树的 \(f_y\) 产生的贡献和。第二者通过 \(g_{y,i}\) 得到,具体实现可以记录一个所有 \(y\) 左兄弟的 \(g\) 数组的前缀和,可做到不重不漏。而最后一部分可以考虑以 \(y\) 为根的子树中的链与所有子树外的点一一匹配,产生的贡献为 \(f_y \times (s_x-s_y)\),其中 \(s_x\) 为以 \(x\) 为根的子树大小。
时间复杂度 \(O(nk^2)\)
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const ll SIZE = 200005;
const ll mod = 998244353;
ll n, K;
ll head[SIZE], ver[SIZE*2], nxt[SIZE*2], tot;
ll dp[SIZE], f[SIZE], g[SIZE][5], siz[SIZE];
inline ll rd(){
ll f = 1, x = 0;
char ch = getchar();
while(ch < '0' || ch > '9'){
if(ch == '-') f = -1;
ch = getchar();
}
while(ch >= '0' && ch <= '9'){
x = (x << 1) + (x << 3) + (ch ^ 48);
ch = getchar();
}
return f*x;
}
void add(ll x, ll y){
ver[++tot] = y, nxt[tot] = head[x];
head[x] = tot;
}
void dfs(ll x, ll fa){
ll jl[5]; memset(jl, 0, sizeof(jl));
g[x][0] = 1; siz[x] = 1;
for(int i = head[x]; i; i = nxt[i]){
ll y = ver[i];
if(y == fa) continue;
dfs(y, x);
dp[x] += dp[y]; siz[x] += siz[y];
for(int j = 0; j < K; j++){
dp[x] += ceil((double)(j+1) / (double)K) * g[y][j];
}
for(int j = 0; j < K; j++){
for(ll k = 0; k < K; k++){
dp[x] += ceil((double)(j+k+2) / (double)K) * jl[j] * g[y][k];
}
}
for(int j = 0; j < K; j++) jl[j] += g[y][j];
}
for(int i = head[x]; i; i = nxt[i]){
ll y = ver[i];
if(y == fa) continue;
for(int j = 0; j < K; j++) g[x][(j+1)%K] += g[y][j];
dp[x] += (siz[x] - siz[y]) * f[y];
f[x] += f[y] + g[y][K-1];
}
}
int main(){
n = rd(), K = rd();
for(int i = 1; i < n; i++){
ll x = rd(), y = rd();
add(x, y); add(y, x);
}
dfs(1, 0);
printf("%lld", dp[1]);
return 0;
}