[JSOI2018] 潜入行动
题意
link
在 \(n(\leq 10^5)\) 个点选 \(k(\leq 100)\) 个点, 被选出的点周围的点不包括自己都会被覆盖,求所有点都被覆盖的方案数。
树形背包
其实就是简单的树形背包,分类讨论清楚就行了。
状态
\(f[u][k][1/0][1/0]\) 表示节点 \(u\), 已经选了 \(k\) 个点,现在 \(u\) 是否被覆盖,\(u\) 是否被选, 并且除了根外所有节点都被覆盖。
初始状态
对于一个点的树,\(f[u][0][0][0] = f[u][1][0][1] = 1\)。
转移
考虑两个树合并,每个结点最开始是一个点的树, 以下的转移都满足状态。
\(f[u][k][0][0] = f[u][i][0][0] * f[v][j][1][0]\)
\(f[u][k][0][1] = f[u][i][0][1] * (f[v][j][0][0] + f[v][j][1][0])\)
\(f[u][k][1][0] = f[u][i][1][0] * (f[v][j][1][0] + f[v][j][1][1]) + f[u][i][0][0] * f[v][j][1][1]\)
\(f[u][k][1][1] = f[u][i][0][1] * (f[v][j][0][1] + f[v][j][1][1]) + f[u][i][1][1] * (f[v][j][0][0] + f[v][j][0][1] + f[v][j][1][0] + f[v][j][0][1])\)
最后的答案即是 \(f[1][k][1][0] + f[1][k][1][1]\)
分析
树形背包如果转移的次数是 两个树的大小的乘积,那么就有整体时间复杂度 \(O(nk)\)。
空间也是 \(O(nk)\)。
代码
代码是好久之前的,状态后两维是反过来的。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int mod = 1000000007;
const int MAXN = 100010;
int n, m;
int a[MAXN];
struct edge{
int next, point;
}e[MAXN << 1];
int first[MAXN];
void add(int u, int v){
static int cnt = 0;
++cnt;
e[cnt].point = v;
e[cnt].next = first[u];
first[u] = cnt;
}
int f[MAXN][103][2][2], siz[MAXN], tmp[210][2][2];
int Add(int a, int b){
int s = a + b;
while (s > mod) s -= mod;
return s;
}
void dp(int u, int fa){
siz[u] = 1;
f[u][0][0][0] = f[u][1][1][0] = 1;
for(int i = first[u]; i; i = e[i].next){
int v = e[i].point;
if(v == fa) continue;
dp(v, u);
for(int j = 0; j <= min(siz[u], m); j++)
for(int x = 0; x < 2; x++)
for(int y = 0; y < 2; y++)
tmp[j][x][y] = f[u][j][x][y],
f[u][j][x][y] = 0;
for(int j = 0; j <= min(siz[u], m); j++)
for(int k = 0; k <= min(siz[v], m - j); k++)
{
f[u][j + k][0][0] = Add(f[u][j + k][0][0], (ll)tmp[j][0][0] * f[v][k][0][1] % mod);
f[u][j + k][0][1] = Add(f[u][j + k][0][1], Add((ll)tmp[j][0][0] * f[v][k][1][1] % mod,(ll)tmp[j][0][1] * Add(f[v][k][0][1], f[v][k][1][1]) % mod));
f[u][j + k][1][0] = Add(f[u][j + k][1][0], (ll)tmp[j][1][0] * Add(f[v][k][0][0], f[v][k][0][1]) % mod);
f[u][j + k][1][1] = Add(f[u][j + k][1][1], Add((ll)tmp[j][1][1] * Add(Add(f[v][k][0][0], f[v][k][0][1]), Add(f[v][k][1][0], f[v][k][1][1])) % mod, (ll)tmp[j][1][0] * Add(f[v][k][1][0], f[v][k][1][1]) % mod));
}
siz[u] += siz[v];
}
}
int main(){
scanf("%d%d", &n, &m);
for(int i = 1; i < n; i++){
int u, v;
scanf("%d%d", &u, &v);
add(u, v);
add(v, u);
}
dp(1, 1);
printf("%d\n", (f[1][m][0][1] + f[1][m][1][1]) % mod);
return 0;
}

浙公网安备 33010602011771号