[HEOI2013]SAO题解
[HEOI2013]SAO
标签
拓扑排序 + 计数 + dp + 树形dp
思路
题意是给一个树状有向图,求有多少种拓扑序。
\(\star\) 设 \(f[i][j]\) 表示在 \(i\) 的子树中,\(i\) 的拓扑序为 \(j\) 的方案数。
那么考虑转移,合并 \(u\) 和 \(u\) 的子树 \(v\)。分两种情况。
- \(u\) 指向 \(v\)
\(f[u][p3] = \sum_{p1 = 1}^{siz_x} \sum_{p2=1}^{siz_y} f[u][p1] * f[v][p2] * \tbinom{p3-1}{p1-1} * \tbinom{siz_u+siz_v-p3}{siz_u-p1} [p1 \leq p3 \leq p2 + p1 - 1]\)
\(p1 \leq p3\) 因为原来拓扑序在 \(p1\) 左侧的点,合并后拓扑序一定还在 \(p1\) 左侧, 因为原来拓扑序在 \(p1\) 右侧的点,合并后拓扑序一定还在 \(p1\) 右侧
\(p3 \leq p2 + p1 - 1\) , 因为合并后 \(v\) 的拓扑序一定比 \(u\) 大,所以 \([p2,siz_v]\) 一定在 \(u\) 右侧, \([1,p2]\) 既可能在 \(u\) 左,又可能在 \(u\) 右,所以 \(p3_{max} = p2 + p1 - 1\)
合并后,如果 \(u\) 的拓扑序是 \(p3\) , 那么左边的 \(p3 - 1\) 个点中,有 \(p1 - 1\) 个是原来的,右边的 \(siz_u+siz_v - p3\) 个点中,有 \(siz_u - p1\) 个点是原来的,所以要乘上 \(\tbinom{p3-1}{p1-1} * \tbinom{siz_u+siz_v-p3}{siz_u-p1}\)
但是这样转移是 \(O(n^3)\) 的,考虑优化。
\(\sum_{p2=1}^{p3 - p1 + 1} f[v][p2]\) 可以用前缀和优化.
这样复杂度就是 \(O(n^2)\).
-
\(v\) 指向 \(u\)
和 1 差不多。
最后 \(ans = \sum_{i=1}^{n}f[1][i]\)
Code
#include <iostream>
#include <cstdio>
#include <cstring>
#define ll long long
using namespace std;
const int N = 1010;
const ll mod = 1e9 + 7;
int T;
int n;
int head[N], nxt[2 * N], to[2 * N], op[2 * N], e_tot;
int siz[N];
ll C[N][N];
ll f[N][N], g[N], s[N][N];
void link(int x, int y, int z)
{
nxt[++e_tot] = head[x], head[x] = e_tot, to[e_tot] = y, op[e_tot] = z;
}
ll add(ll x, ll y)
{
return x + y >= mod ? x + y - mod : x + y;
}
ll suf(ll x, ll y)
{
return x - y < 0 ? x - y + mod : x - y;
}
void dfs(int u, int _fa)
{
siz[u] = 1, f[u][1] = 1;
for (int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if (v == _fa) continue;
dfs(v, u);
memcpy(g, f[u], sizeof(f[u]));
memset(f[u], 0, sizeof(f[u]));
if (op[i] == 1)
{
for (int p1 = 1; p1 <= siz[u]; ++p1)
{
for (int p3 = p1; p3 < p1 + siz[v]; ++p3)
{
f[u][p3] = add(f[u][p3], C[p3 - 1][p1 - 1] * C[siz[u] + siz[v] - p3][siz[u] - p1] % mod * g[p1] % mod * suf(s[v][siz[v]], s[v][p3 - p1]) % mod);
}
}
}
else
{
for (int p1 = 1; p1 <= siz[u]; ++p1)
{
for (int p3 = p1 + 1; p3 <= p1 + siz[v]; ++p3)
{
f[u][p3] = add(f[u][p3], C[p3 - 1][p1 - 1] * C[siz[u] + siz[v] - p3][siz[u] - p1] % mod * g[p1] % mod * s[v][p3 - p1] % mod);
}
}
}
siz[u] += siz[v];
}
for (int i = 1; i <= siz[u]; ++i)
{
s[u][i] = add(s[u][i - 1], f[u][i]);
}
}
int main()
{
scanf("%d", &T);
for (int i = 0; i <= 1005; ++i) C[i][0] = 1;
for (int i = 1; i <= 1005; ++i)
{
for (int j = 1; j <= i; ++j)
{
C[i][j] = add(C[i - 1][j - 1], C[i - 1][j]);
}
}
while (T--)
{
scanf("%d", &n);
memset(f, 0, sizeof(f));
memset(head, 0, sizeof(head));
memset(nxt, 0, sizeof(nxt));
memset(to, 0, sizeof(to));
memset(op, 0, sizeof(op));
e_tot = 0;
for (int i = 1; i < n; ++i)
{
int x, y;
char c;
scanf("%d %c %d", &x, &c, &y);
++x, ++y;
link(x, y, c == '<');
link(y, x, c == '>');
}
dfs(1, 0);
cout << s[1][n] << endl;
}
return 0;
}

浙公网安备 33010602011771号