Codeforces 1111E DP + 树状数组 + LCA + dfs序
题意:给你一棵 n 个点的树,有 q 次询问。每次询问给出 k 个点、组数上限 m 和根 r:把树看作以 r 为根,将这 k 个点分成至多 m 组,要求 (1) 每组至少有一个点;(2) 同一组内任何一个点都不能是组内其它点的祖先。问有多少种分组方案,答案对 1000000007 取模。
思路:https://blog.csdn.net/BUAA_Alchemist/article/details/86765501
代码:
#include <bits/stdc++.h>
#define LL long long
#define lowbit(x) (x & (-x))
using namespace std;
// Modulus applied to every DP value and to the printed answer.
const LL mod = 1000000007;
// Capacity of the per-node arrays (node ids are used directly as indices).
const int maxn = 100010;
// Adjacency lists; add() stores each edge in both directions.
vector<int> G[maxn];
// Nodes listed in the query currently being processed; cleared after each query.
vector<int> a;
// dfn: visit order assigned by dfs(); sz: subtree size computed by dfs();
// tot: counter behind dfn; t: highest level index used in the ancestor table f.
int dfn[maxn], sz[maxn], tot, t;
// DP storage, maxn x 310 long longs (about 248 MB if every row is touched).
LL dp[maxn][310];
// Number of tree nodes, read in main().
int n;
// Registers the undirected edge x - y by appending each endpoint to the
// other endpoint's adjacency list.
void add(int x, int y) {
    G[x].emplace_back(y);
    G[y].emplace_back(x);
}
// Numbers the nodes reachable from x in visit order and computes subtree sizes.
//
// x  - node to start from.
// fa - neighbour of x to treat as its parent (never descended into); pass a
//      value that is not a node id when x has no parent.
//
// On return, every visited node v has dfn[v] set to its 1-based visit index
// (continuing from the global counter tot) and sz[v] set to the number of
// nodes in its subtree, so the subtree of v occupies the index range
// [dfn[v], dfn[v] + sz[v] - 1].
//
// The traversal keeps its own stack instead of recursing: a path-shaped tree
// would otherwise need one call frame per node, which can exhaust the call
// stack for large inputs. Children are still visited in adjacency-list order,
// so the numbering is the same as the recursive formulation produces.
void dfs(int x, int fa) {
    struct Frame {
        int node;          // node this frame is expanding
        int parent;        // neighbour to skip while expanding
        std::size_t next;  // index of the next adjacency entry to look at
    };

    std::vector<Frame> pending;
    dfn[x] = ++tot;
    sz[x] = 1;
    pending.push_back(Frame{x, fa, 0});

    while (!pending.empty()) {
        Frame &top = pending.back();
        if (top.next < G[top.node].size()) {
            // Copy what we need before push_back can invalidate `top`.
            const int node = top.node;
            const int child = G[node][top.next++];
            if (child == top.parent) continue;
            dfn[child] = ++tot;
            sz[child] = 1;
            pending.push_back(Frame{child, node, 0});
        } else {
            // All neighbours handled: fold this subtree's size into the parent frame.
            const int finished = top.node;
            pending.pop_back();
            if (!pending.empty()) sz[pending.back().node] += sz[finished];
        }
    }
}
// Work queue used by bfs().
queue<int> q;
// dep[v]: depth of v with node 1 at depth 1; 0 means "not reached yet".
// f[v][j]: the 2^j-th ancestor of v, or 0 when there is none.
int dep[maxn], f[maxn][20];
// Breadth-first pass from node 1 that fills dep[] and the ancestor table f[][].
//
// Node 1 gets depth 1, so a zero depth doubles as the "not visited" marker.
// For every newly reached node, f[node][0] is its parent and
// f[node][lvl] = f[f[node][lvl - 1]][lvl - 1] for lvl = 1..t; the parent's row
// is already complete because parents leave the queue before their children.
void bfs() {
    dep[1] = 1;
    q.push(1);
    while (!q.empty()) {
        const int cur = q.front();
        q.pop();
        for (int nxt : G[cur]) {
            if (dep[nxt] != 0) continue;  // already reached (includes cur's parent)
            dep[nxt] = dep[cur] + 1;
            f[nxt][0] = cur;
            for (int lvl = 1; lvl <= t; ++lvl) {
                const int half = f[nxt][lvl - 1];
                f[nxt][lvl] = f[half][lvl - 1];
            }
            q.push(nxt);
        }
    }
}
// Returns the lowest common ancestor of x and y in the tree rooted at node 1,
// using the dep[] / f[][] tables filled by bfs().
//
// Missing ancestors are stored as 0 and dep[0] is 0, which is below every real
// depth, so jumps that would leave the tree are rejected by the depth test.
int lca(int x, int y) {
    // Make y the deeper (or equally deep) node.
    if (dep[y] < dep[x]) swap(x, y);

    // Raise y to the depth of x, largest jumps first.
    int level = t;
    while (level >= 0) {
        if (dep[f[y][level]] >= dep[x]) y = f[y][level];
        --level;
    }
    if (y == x) return x;

    // Raise both while their ancestors differ; they end just below the LCA.
    for (level = t; level >= 0; --level) {
        if (f[x][level] == f[y][level]) continue;
        x = f[x][level];
        y = f[y][level];
    }
    return f[x][0];
}
// Binary indexed tree over positions 1..n holding int values.
struct BIT {
    int c[maxn];

    // Returns the sum of everything added at positions 1..x (0 when x is 0).
    int ask(int x) {
        int total = 0;
        while (x) {
            total += c[x];
            x -= x & -x;  // drop the lowest set bit
        }
        return total;
    }

    // Adds y at position x. Positions above n are ignored; x must be >= 1.
    void add(int x, int y) {
        while (x <= n) {
            c[x] += y;
            x += x & -x;  // move to the next node covering x
        }
    }
};
// Binary indexed tree indexed by dfn; main() adds +1 / -1 around the dfn range
// of each query node's subtree and removes those updates after the query.
BIT tr;
// h[1..k]: per-node counts computed for the current query (sorted before the DP).
// vis[v]: 1 while v belongs to the current query, 0 otherwise.
int h[maxn], vis[maxn];
int main() {
int u, v, T;
scanf("%d%d", &n, &T);
t = (int)(log(n) / log(2)) + 1;
for (int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
add(u, v);
}
dfs(1, -1);
bfs();
int k, m, r, x;
LL ans = 0;
while(T--) {
scanf("%d%d%d",&k, &m, &r);
ans = 0;
for (int i = 1; i <= k; i++) {
scanf("%d", &x);
vis[x] = 1;
a.push_back(x);
tr.add(dfn[x], 1);
tr.add(dfn[x] + sz[x], -1);
}
for (int i = 0; i < k; i++) {
int LCA = lca(a[i], r);
h[i + 1] = tr.ask(dfn[a[i]]) + tr.ask(dfn[r]) - 2 * tr.ask(dfn[LCA]) + vis[LCA] - 1;
}
sort(h + 1, h + 1 + k);
dp[0][0] = 1;
for (int i = 1; i <= k; i++)
for (int j = 0; j <= min(i, m); j++) {
if(j > 0)
dp[i][j] = (LL)((LL)dp[i - 1][j - 1] + ((LL)dp[i - 1][j] * max(0, j - h[i])) % mod) % mod;
}
for (int i = 1; i <= m; i++)
ans = (ans + dp[k][i]) % mod;
printf("%lld\n", ans);
for (int i = 0; i < k; i++) {
tr.add(dfn[a[i]], -1);
tr.add(dfn[a[i]] + sz[a[i]], 1);
vis[a[i]] = 0;
}
a.clear();
}
}

浙公网安备 33010602011771号