多彩的树
多彩的树
有一棵树包含 N 个节点,节点编号从 \(1\) 到 \(N\)。节点总共有 \(K\) 种颜色,颜色编号从 \(1\) 到 \(K\)。第 \(i\) 个节点的颜色为 \(A_i\)。
\(Fi\) 表示恰好包含 \(i\) 种颜色的路径数量。请计算:\((\sum_{r=1}^n*(F_i * 131^i)) mod (10^9 + 7)\)
思路:
首先我们发现树形dp解决不了这个问题,因为状态的设置为 \(dp[N][1<<k]\) ,而在转移的过程时间复杂度完全过不去,于是乎得另辟蹊径。我们发现同种颜色数量能由不同的颜色组合构成,而这种组合之间又不会互相干扰。我们可以把所有颜色组合搭配给表示出来。选择 \(1、3、5\) 号颜色为例, 其二进制状态可以表示为10101。我们就可以以 \(1、3、5\) ,这三种颜色的点为起点枚举出所有只包含这三种颜色的连通块。 \(1、3、5\) 这种组合方式的路径数只可能在这些连通块中产生,接下来的问题就是如何求出\(1、3、5\),这种组合方式的路径数。我们要用到容斥原理, 我们可以很容易求出这个联通块的所有路径数(不一定全都包含三种颜色),为 \(cnt * (cnt - 1) / 2 + cnt\) ,求出来之后我们只需要减去不包含三种颜色的路径数,这时候我们要枚举出所有可能\(10100、00101、10001、10000、00100、00001\)。那么很显然答案出来了\(dp[10101] = cnt * (cnt - 1) / 2 + cnt - dp[10100] - dp[00101] - dp[10001] - dp[10000] - dp[00100] - dp[00001]\)。接下来的难点来到了如何枚举出所有不合法的子情况,这里不加以证明,给出实现方式。定义 \(i\)为当前几种颜色的表示状态,那么其不合法的子状态表示为
for (int j = (i - 1) & i; j; j = (j - 1) & i)
代码:
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = 5e4 + 10;
const int mod = 1e9 + 7;
int dp[1 << 11], vis[N];
int a[N], n, m, cnt, ct[1 << 11], k, fac[20];
vector<int>G[N];
void dfs(int u, int fa, int s)
{
if ((s | (1 << (a[u] - 1))) != s || vis[u]) return ;
vis[u]++, cnt++;
for (auto v: G[u])
{
if (v == fa) continue;
dfs(v, u, s);
}
}
int vaild(int i)
{
int tot = 0;
while(i)
{
if (i & 1) tot++;
i >>= 1;
}
return tot;
}
void init()
{
fac[0] = 1;
for (int i = 1; i <= 18; i++)
{
fac[i] = 1ll * fac[i - 1] * 131 % mod;
}
for (int i = 1; i <= (1 << k) - 1; i++)
{
ct[i] = vaild(i);
}
}
signed main()
{
// std::ios::sync_with_stdio(false);
// cin.tie(NULL);cout.tie(NULL);
cin >> n >> k;
init();
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = 1; i < n; i++)
{
int u, v;
cin >> u >> v;
G[u].push_back(v);
G[v].push_back(u);
} //建树
for (int i = 1; i <= (1 << k) - 1; i++)
{
memset(vis, 0, sizeof(vis)); //每次都要记录,以防止重复计算
for (int j = 1; j <= n; j++)
{
cnt = 0;
if (!vis[j] && (i & (1 << (a[j] - 1))))
{
dfs(j, -1, i);
}
dp[i] = (dp[i] + cnt + cnt * (cnt - 1) / 2) % mod;
}
}
int ans = 0;
for (int i = 1; i <= (1 << k) - 1; i++)
{
for (int j = (i - 1) & i; j; j = (j - 1) & i) //枚举所有不合法路径
{
dp[i] -= dp[j];
}
ans = (ans + 1ll * fac[ct[i]] * dp[i] % mod) % mod;
}
cout << (ans + mod) % mod <<"\n";
}

浙公网安备 33010602011771号