题解:qoj15303 Basic Counting Practice Problems
upd on 2025.12.17:和 @FLY_lai 讨论之后会了第二种做法的时间复杂度分析。
题意:对于一棵 \(1\) 为根的树和一个序列 \(c\),还有一个排列 \(p\)。定义 \(v_i\) 为 \(i\) 子树内仅经过 \(p_j \le p_i\) 的点 \(j\) 可到达的点的个数(包括 \(i\)),那么对于点 \(i\) 的权值就是 \(c_{v_i}\)。
现在给出一棵树和序列 \(c\),请求出 \(\forall 1\le i,j \le n,p_i=j\) 的情况下 \(i\) 的权值之和。\(n\le 700\)。
做法:
这里给出两个做法,一个是我的做法,一个是 jiangly 讲的做法,是一个非常厉害的 trick。
首先我们先考虑对于每个 \(p_i=j\),我令剩余的数 \(<j\) 的为 \(0\) 否则为 \(1\),那么如果我能确定子树内目前有多少个 \(1\) 和目前和子树根连通的 \(0\) 的个数,那么我就可以直接计算出来一个贡献。所以很自然地,我们设一个状态 \(dp_{u,x,y}\) 代表对于一个节点 \(u\),我们连通的 \(0\) 有 \(x\) 个,子树内的 \(1\) 有 \(y\) 个。转移就是先让子树上的合并上来,是一个背包状物。然后考虑 \(u\) 填 \(0/1\) 去更新即可,预处理这个 dp 复杂度是 \(O(n^4)\),计算答案因为我还要枚举 \(p_i\) 的值,所以也是 \(O(n^4)\) 的,无法通过。
那么既然都是一个背包状物了,那么我们很可以用拉格朗日插值去对预处理 dp 的部分优化,为了方便我们后面的做法,我们保留 \(y\) 这一维而差掉 \(x\),预处理复杂度变为 \(O(n^3)\)。稍微细说一个细节,因为我们有可以把 \(u\) 放置成 \(1\) 然后 \(x\) 直接置 \(0\) 的情况,所以这种情况就是直接是 \(h_{u,y}\),即子树内不考虑 \(u\),已经有 \(y\) 个 \(1\) 的方案数,不受插值的系数影响,可以预处理。
计算答案时,我们可以枚举 \(u,y\) 然后预处理出来 \(dp_{u,y}\) 和 \(c\) 这两个序列点乘加和之后的值,这样我们计算 \(p_i=j\) 的答案时就只需要枚举我子树内有 \(k\) 个比他小的,然后用 \(dp_{u,k}\) 乘上一堆组合数和阶乘的系数即可。问题在于,我们如果需要求出来 \(dp_{u,y}\) 这个序列,我们目前只有点值,做插值得到序列的话需要 \(n^2\),复杂度仍然是 \(O(n^4)\) 的。
抽象一下问题,我们现在需要解决一个问题:给出一个向量 \(\text{a}\),对其进行插值得到向量 \(\text{b}\),计算然后和另一个向量 \(\text{c}\) 点乘之后序列和,需要 \(O(n)\)。
正常做完全做不了,注意到最后点乘的序列 \(c\) 是定的,我们考虑进行一些神秘操作,比如我不进行插值,而是我先令 \(c\) 进行一个类似逆插值的操作,使得最后点乘完结果是一样的,这样就可以允许预处理 \(O(n^3)\) 即可。
然后我在这里卡了 2.5h 不知道咋做,之后用了一下 deepseek 发现,拉差其实是可以用矩阵乘法刻画的,具体来说,记范德蒙德矩阵为:
那么根据我们问题里的刻画,等于说我们会有 \(V\text{b}^T=\text{a}^T\),这一步是拉差,然后我们要令 \(S=\text{c}\text{b}^T\),带入一下,那么就会有 \(S=\text{c}V^{-1}\text{a}^T\)。因为 \(\text{c}V^{-1}\) 是一个固定的东西,所以我们可以预处理出来,这里 deepseek 给的形式的原因,我代码里是按照 \((V^{-1}\text{c}^T)^T\) 的方式写的。
然后把这个形式带回到原问题里就解决了,复杂度 \(O(n^3)\),貌似常数比较大喜得最劣解。
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 1405, mod = 1e9 + 7;
int n, a[maxn][maxn];
vector<int> e[maxn];
int qpow(int x, int k, int p) {
int res = 1;
while(k) {
if(k & 1)
res = res * x % p;
x = x * x % p, k >>= 1;
}
return res;
}
int c[maxn], coef[maxn];
void prepare() {
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
a[i][j] = (i == 1 ? 1 : a[i - 1][j] * j) % mod;
for (int i = 1; i <= n; i++)
a[i][i + n] = 1;
for (int i = 1; i <= n; i++) {
int p = 0;
for (int j = i; j <= n; j++)
if(a[j][i]) {
p = j;
break;
}
swap(a[p], a[i]);
int inv = qpow(a[i][i], mod - 2, mod);
for (int j = 1; j <= n; j++) {
if(i == j)
continue;
int v = a[j][i] * inv % mod;
for (int k = 1; k <= 2 * n; k++)
a[j][k] = (a[j][k] - a[i][k] * v % mod + mod) % mod;
}
}
for (int i = 1; i <= n; i++) {
int inv = qpow(a[i][i], mod - 2, mod);
for (int j = 1; j <= 2 * n; j++)
a[i][j] = a[i][j] * inv % mod;
}
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
a[i][j] = a[i][j + n];
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
coef[i] = (coef[i] + a[i][j] * c[j]) % mod;
}
int X = 1, f[maxn / 2][maxn / 2], g[maxn / 2][maxn / 2], h[maxn / 2][maxn / 2], jc[maxn], sz[maxn], t[maxn];
void dfs(int u, int fa) {
for (int i = 0; i < e[u].size(); i++) {
int v = e[u][i];
if(v == fa)
continue;
dfs(v, u);
}
sz[u] = 0; f[u][0] = 1;
for (int i = 0; i < e[u].size(); i++) {
int v = e[u][i];
if(v == fa)
continue;
for (int j = 0; j <= sz[u]; j++)
t[j] = f[u][j];
for (int j = 0; j <= sz[u] + sz[v]; j++)
f[u][j] = 0;
for (int x = 0; x <= sz[u]; x++)
for (int y = 0; y <= sz[v]; y++)
f[u][x + y] = (f[u][x + y] + t[x] * f[v][y]) % mod;
sz[u] += sz[v];
}
for (int i = 0; i <= sz[u]; i++)
g[u][i] = f[u][i];
for (int i = 0; i <= sz[u] + 1; i++)
t[i] = 0;
for (int i = 0; i <= sz[u]; i++)
t[i] = (t[i] + f[u][i]) % mod,
t[i + 1] = (t[i + 1] + f[u][i]) % mod;
for (int i = 0; i <= sz[u] + 1; i++)
f[u][i] = t[i];
sz[u]++;
}
void redfs(int u, int fa) {
for (int i = 0; i < e[u].size(); i++) {
int v = e[u][i];
if(v == fa)
continue;
redfs(v, u);
}
sz[u] = 0; f[u][0] = 1;
for (int i = 0; i < e[u].size(); i++) {
int v = e[u][i];
if(v == fa)
continue;
for (int j = 0; j <= sz[u]; j++)
t[j] = f[u][j];
for (int j = 0; j <= sz[u] + sz[v]; j++)
f[u][j] = 0;
for (int x = 0; x <= sz[u]; x++)
for (int y = 0; y <= sz[v]; y++)
f[u][x + y] = (f[u][x + y] + t[x] * f[v][y]) % mod;
sz[u] += sz[v];
}
for (int i = 0; i <= sz[u]; i++)
h[u][i] = (h[u][i] + f[u][i] * coef[X] % mod) % mod;
for (int i = 0; i <= sz[u] + 1; i++)
t[i] = 0;
for (int i = 0; i <= sz[u]; i++)
t[i] = (t[i] + f[u][i] * X) % mod,
t[i + 1] = (t[i + 1] + g[u][i]) % mod;
for (int i = 0; i <= sz[u]; i++)
f[u][i] = t[i];
sz[u]++;
}
int C[maxn][maxn];
signed main() {
cin >> n;
for (int i = 1; i <= n; i++)
cin >> c[i];
prepare();
for (int i = 1; i < n; i++) {
int x, y; cin >> x >> y;
e[x].push_back(y);
e[y].push_back(x);
}
jc[0] = 1;
for (int i = 1; i <= n; i++)
jc[i] = jc[i - 1] * i % mod;
dfs(1, 0);
for (int i = 1; i <= n; i++) {
X = i;
redfs(1, 0);
}
C[0][0] = 1;
for (int i = 1; i <= n; i++) {
C[i][0] = 1;
for (int j = 1; j <= i; j++)
C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % mod;
}
for (int i = 1; i <= n; i++) {
sz[i]--;
for (int j = 1; j <= n; j++) {
int ans = 0;
for (int k = 0; k <= sz[i]; k++)
ans = (ans + C[n - j][k] * jc[k] % mod * C[j - 1][sz[i] - k] % mod * jc[sz[i] - k] % mod * jc[n - sz[i] - 1] % mod * h[i][k] % mod) % mod;
cout << ans << " ";
}
cout << endl;
}
return 0;
}
接下来是 jiangly 讲的做法,非常厉害的一个 trick,类似的题还有 abc311h。
就是我们考虑,答案的形态一定是一个连通的 \(0\) 还有一堆 \(1\) 给他封死了,所以我们考虑记录 \(0\) 有 \(x\) 个,封死的 \(1\) 有 \(y\) 个,这样记录 dp 状态,最后计算答案的时候我们做一个类似递推的计算方式,我们记 \(g_{x,y}\) 代表节点 \(u\) 子树内的答案,整棵树里用了 \(x\) 个 \(0\) 和 \(y\) 个 \(1\),那么答案计算有几种,第一种是从 \(dp_{u,x,y}\) 来的,第二种是我一个没确定节点选择了 \(0/1\),按上述过程转移即可,复杂度 \(n^3\),重点是怎么计算 dp 值。
考虑一个类似于 dfn 序上 dp 的东西,我们在 dfs 的时候同时传入一个 dp 数组,依赖于这个数组进行 dp,转移分为两种,第一种是我这个点选 \(1\),那么就直接跳过整个子树的转移,否则就向下选即可,同时传入目前的 dp 值数组。注意这里的 dp 数组是前面的儿子做完的 dp 数组会传上来给后面的儿子用的,可能有点意识流,不太懂就可以直接把他当按 dfn 序 dp 即可。
但是这样做只解决了节点 \(1\) 的问题可以做到 \(O(n^3)\),做不了所有点的。我们考虑对这个东西做启发式,对于一条重链上的点,我们直接遍历到底部,优先遍历重子,然后从下往上做,同时处理一条重链的答案,其他的我们就只做 dp,但是因为他传入时带了别的地方的 dp 值所以就不能做答案。把所有的重链按链顶深度全部处理了即可。复杂度看上去是带 log 的,实际严格分析是不带的,我们来稍微证明一下。
首先一个点作为重链节点被处理到的复杂度很容易分析是三次的,只用考虑作为一个轻儿子被多次遍历到的情况。考虑一个点在第一次重链处理时被遍历到的顺序是 \(s\),会花费 \(O(s^2)\) 的代价处理。那么考虑第二次重链处理,因为我这个是第一次重链的轻子数,而这个 \(s\) 一定遍历完了重子树,所以这里第二次处理的时候 \(s\) 至少折半,所以总的对于一个节点的复杂度是 \(O(\sum\limits_{i=0} (\frac{s}{2^i}) ^ 2) = O(s^2)\)。那么总的复杂度加和就是 \(O(n^3)\)。

浙公网安备 33010602011771号