P6803 [CEOI 2020] 星际迷航 题解
以下 DP 状态大写字母为根,小写字母为子树内。
设 \(f_u\) 为 \(u\) 点开始往子树内走,是否胜利。
为了换根 DP,设 \(g_u\) 为 \(u\) 儿子中 \(f_v = 0\) 的个数。
考虑 \(D = 1\) 的情况,枚举加边两端 \((u, v)\),显然只有 \(F_v = 0\) 时才有影响,假设有 \(k\) 个 \(F_v = 0\),那么我们还需要计算有多少 \(u\) 连接了 0 之后会反转根的 \(f\)。
为了计算这个,设 DP \(h_u\) 表示 \(u\) 子树内有多少点连 0 之后改变 \(u\) 的状态,\(h_u = [g_u = 0]\sum_v [f_v = 0]h_v + [g_u = 1](1 + \sum_vh_v)\)。
我们把 \(g_u = 0, 1\) 的式子分别设为 \(h0_u, h1_u\),这样就可以换根出 \(H\)。
接下来继续考虑 \(D = 1\),显然答案就是 \([f_1 = 1]((n - k)n + (n - H_1)k) + [f_1 = 0]H_1k\)。
考虑拓展到 \(D > 1\),那么设 \(c0, c1\) 表示 \([i,D]\) 部分的胜负态,枚举当前树的根可以转移 \(c0, c1\)。
使用矩阵乘法计算出 \([1, D]\) 的部分,然后再套用 \(D = 1\) 的公式即可。
#include <iostream>
#include <cstring>
#include <queue>
#include <algorithm>
#define int long long
using namespace std;
const int N = 2e5 + 10, mod = 1e9 + 7;
int n, d, g[N], G[N], f[N], F[N], h[N], H[N], h0[N], h1[N], H0[N], H1[N];
vector<int> grp[N];
void dfs(int u, int fa) {
h0[u] = 1;
for(auto v : grp[u]) {
if(v == fa) continue;
dfs(v, u);
g[u] += f[v] == 0;
h1[u] += (f[v] == 0 ? h[v] : 0);
h0[u] += h[v];
}
f[u] = g[u] > 0;
h[u] = (g[u] == 1 ? h1[u] : (g[u] == 0 ? h0[u] : 0));
}
void dfs2(int u, int fa) {
for(auto v : grp[u]) {
if(v == fa) continue;
int gu = G[u] - (f[v] == 0), fu = gu > 0, h0u = H0[u] - h[v], h1u = H1[u] - (f[v] == 0 ? h[v] : 0), hu = (gu == 1 ? h1u : (gu == 0 ? h0u : 0));
G[v] = (g[v] + (fu == 0)), F[v] = G[v] > 0;
H0[v] = h0[v] + hu;
H1[v] = h1[v] + (fu == 0 ? hu : 0);
H[v] = (G[v] == 1 ? H1[v] : (G[v] == 0 ? H0[v] : 0));
dfs2(v, u);
}
}
struct Mat {
int m[2][2], r, c;
void clear(int R, int C) {
r = R, c = C;
m[0][0] = m[0][1] = m[1][0] = m[1][1] = 0;
}
void init() {
m[0][0] = m[1][1] = 1;
}
Mat operator * (const Mat &W) const {
Mat res; res.clear(r, W.c);
for(int k = 0; k < c; k ++)
for(int i = 0; i < r; i ++)
for(int j = 0; j < W.c; j ++)
res.m[i][j] = (res.m[i][j] + m[i][k] * W.m[k][j]) % mod;
return res;
}
} M;
Mat qmi(Mat a, int b) {
Mat res; res.clear(a.r, a.c), res.init();
while(b) {
if(b & 1) res = res * a;
a = a * a, b >>= 1;
}
return res;
}
signed main() {
ios::sync_with_stdio(0), cin.tie(0);
cin >> n >> d;
for(int i = 1, a, b; i < n; i ++)
cin >> a >> b, grp[a].push_back(b), grp[b].push_back(a);
dfs(1, 0);
G[1] = g[1], F[1] = f[1], H[1] = h[1], H0[1] = h0[1], H1[1] = h1[1];
dfs2(1, 0);
int ans = 0, c0 = 0, c1 = 0;
for(int i = 1; i <= n; i ++) if(F[i]) c1 ++; else c0 ++;
M.clear(2, 2);
for(int i = 1; i <= n; i ++) {
if(!F[i]) M.m[1][0] += n, M.m[0][0] += n - H[i], M.m[0][1] += H[i];
else M.m[0][0] += H[i], M.m[1][1] += n, M.m[0][1] += n - H[i];
}
M.m[0][0] %= mod, M.m[0][1] %= mod, M.m[1][0] %= mod, M.m[1][1] %= mod;
M = qmi(M, d - 1);
int cc0 = (c0 * M.m[0][0] + c1 * M.m[1][0]) % mod;
int cc1 = (c0 * M.m[0][1] + c1 * M.m[1][1]) % mod;
if(F[1]) cout << (cc1 * n + (n - H[1]) * cc0) % mod << '\n';
else cout << H[1] * cc0 % mod << '\n';
return 0;
}

QwQ
浙公网安备 33010602011771号