CF1613F - Tree Coloring 题解
大家好,这里是一个不会 NTT 的菜鸡在 xjbbb(怎么说话呢,骂两个老师高兴啊?)。NTT 的板子都是网上剽的(需要注意的是,NTT 需要将度数变成 \(2\) 的整次幂,但是 INTT 之后一定要 resize 回 \(\deg a + \deg b - 1\),不然可能会指数级增长)。
朴素的方法
前面的部分感觉并不难想。考虑正难则反,计算至少有一个点 \(x\) 使得 \(a_x = a_{fa_x} + 1\)。考虑容斥,钦定某些边满足,形成若干条直链(一个点必然不会有两条通向儿子的边满足),每条直链分配的 \(a\) 值事连续的区间。那么方案数就是 \(c!\),其中 \(c\) 事直链数量,因为这些完整的区间的排列顺序唯一决定了它们分配的 \(a\) 值。而显然 \(c = n - i\),其中 \(i\) 事选的边的数量。
于是现在就是要求对于每个 \(i\),选出 \(i\) 条边,满足每个节点最多有一条通向儿子的边被选,的方案数。容易发现,这其实就是 \(b_x = |son_x|\) 中选 \(i\) 的权值积的和。这貌似是分治 NTT 经典问题(?)。对每个 \(b_j\),其选或不选的生成函数为 \(1 + b_jx\),最后答案就是 \(\prod(1 + b_jx)\) 的各项系数。这玩意可以 cdq 分治 + NTT,复杂度 2log。(当然也可以任意顺序启发式合并 + NTT)
code
constexpr int N = 1e6 + 10;
fc_init(N);
int n;
int deg[N];
struct poly : vi {
using vi::vi;
static constexpr int g = 3;
void NTT(int f) {
poly &a = *this;
int lim = 0;
while((1 << lim) < a.size()) ++lim;
a.resize(1 << lim, 0);
static int R[N];
REP(i, 0, a.size() - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (lim - 1));
REP(i, 0, a.size() - 1) if(i < R[i]) ::swap(a[i], a[R[i]]);
for(int i = 1; i < a.size(); i <<= 1) {
int gn = qpow(g, (mod - 1) / (i << 1));
for(int j = 0; j < a.size(); j += i << 1){
int G = 1;
for(int k = 0; k < i; ++k, G = (ll)G * gn % mod) {
int x = a[j + k], y = (ll)G * a[j + k + i] % mod;
a[j + k] = add(x, y), a[j + k + i] = add(x, -y);
}
}
}
if(f == 1) return;
int nv = inv(a.size()); reverse(a.begin() + 1, a.end());
REP(i, 0, a.size() - 1) a[i] = (ll)a[i] * nv % mod;
}
friend poly operator*(poly x, poly y) {
int sz = x.size() + y.size() - 1;
x.resize(sz, 0), y.resize(sz, 0);
x.NTT(1), y.NTT(1);
REP(i, 0, x.size() - 1) x[i] = (ll)x[i] * y[i] % mod;
x.NTT(-1);
x.resize(sz);
return x;
}
void prt() {
for(int x : *this) cout << x << " "; puts("!");
}
};
poly cdq(int l = 1, int r = n) {
if(l == r) return {1, deg[l]};
int mid = l + r >> 1;
return cdq(l, mid) * cdq(mid + 1, r);
}
void mian() {
n = read();
memset(deg, -1, sizeof(deg)); deg[1] = 0;
REP(i, 1, n - 1) { int x = read(), y = read(); ++deg[x], ++deg[y]; }
poly p = cdq();
int ans = 0;
REP(i, 0, n - 1) addto(ans, (i & 1 ? -1ll : 1ll) * fc[n - i] * p[i] % mod);
prt(ans), pc('\n');
}
1log 方法
来自 EI 的方法。
注意到一件事情:\(\sum b_j = \mathrm O(n)\)。将所有 \(b_j\) 放到桶里去,对每个桶事 \((1 + jx)^{b_j}\),可以直接用二项式定理线性展开。然后再一路暴力 NTT 的话,复杂度显然是 \(\mathrm O\!\left(\sum\limits_{i=1}^n\sum\limits_{j = 1}^iC_j\log n\right)=\mathrm O\!\left(\sum\limits_{i = 1}^n (n-i+1)C_i\log n\right)\),其中 \(C_i\) 事桶 \(i\) 的大小。事实上我们可以重新指定顺序,让 \(n-i+1\) 这个看上去事 \(\mathrm O(n)\) 的东西发挥作用。我们发现必然有 \(\sum\limits_{i=1}^niC_i=\mathrm O(n)\),如果我们使得 \(n-i\) 变成 \(i\) 的话,复杂度就是 \(\mathrm O(n\log n)\) 了,而这只需要倒过来暴力 NTT 即可。
code
constexpr int N = 1e6 + 10;
fc_init(N);
int n;
int deg[N];
struct poly : vi {
using vi::vi;
static constexpr int g = 3;
void NTT(int f) {
poly &a = *this;
int lim = 0;
while((1 << lim) < a.size()) ++lim;
if((1 << lim) > N) exit(114514);
a.resize(1 << lim, 0);
static int R[N];
REP(i, 0, a.size() - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (lim - 1));
REP(i, 0, a.size() - 1) if(i < R[i]) ::swap(a[i], a[R[i]]);
for(int i = 1; i < a.size(); i <<= 1) {
int gn = qpow(g, (mod - 1) / (i << 1));
for(int j = 0; j < a.size(); j += i << 1){
int G = 1;
for(int k = 0; k < i; ++k, G = (ll)G * gn % mod) {
int x = a[j + k], y = (ll)G * a[j + k + i] % mod;
a[j + k] = add(x, y), a[j + k + i] = add(x, -y);
}
}
}
if(f == 1) return;
int nv = inv(a.size()); reverse(a.begin() + 1, a.end());
REP(i, 0, a.size() - 1) a[i] = (ll)a[i] * nv % mod;
}
friend poly operator*(poly x, poly y) {
int sz = x.size() + y.size() - 1;
x.resize(sz, 0), y.resize(sz, 0);
x.NTT(1), y.NTT(1);
REP(i, 0, x.size() - 1) x[i] = (ll)x[i] * y[i] % mod;
x.NTT(-1);
x.resize(sz);
return x;
}
void prt() {
for(int x : *this) cout << x << " "; puts("!");
}
};
int cnt[N];
void mian() {
n = read();
memset(deg, -1, sizeof(deg)); deg[1] = 0;
REP(i, 1, n - 1) { int x = read(), y = read(); ++deg[x], ++deg[y]; }
poly p = {1};
REP(i, 1, n) ++cnt[deg[i]];
PER(i, n, 1) {
poly q(cnt[i] + 1);
int now = 1;
REP(j, 0, cnt[i]) q[j] = (ll)now * comb(cnt[i], j) % mod, now = (ll)now * i % mod;
p = p * q;
}
int ans = 0;
REP(i, 0, min(n - 1, int(p.size()) - 1)) addto(ans, (i & 1 ? -1ll : 1ll) * fc[n - i] * p[i] % mod);
prt(ans), pc('\n');
}
带根号的方法
跟据 \(\sum\limits_{i = 1}^n iC_i = \mathrm O(n)\) 还可以得到一个结论:\(C_i\) 有值的桶只有 \(\mathrm O(\sqrt n)\) 个,根号分治分类讨论即可证明(和固定自然想到根号分治)。
众所周知,在要相乘的多项式比较少的时候,可以都 NTT,一起乘起来,最后只要一遍 INTT。但是代价是每个多项式的 NTT 规模要是所有多项式的和,一般在多项式比较多的时候就不划算了。
这题只有 \(\mathrm O(\sqrt n)\) 个多项式,比较少,可以考虑这个 trick。那么每个多项式都要做规模为 \(n\) 的 NTT,而每个多项式都是 \((1+jx)^{b_j}\) 的形式,它的点值是好求的,就求出二项式点值然后快速幂即可。这样复杂度是 \(\mathrm O(n\sqrt n\log n)\),虽然跟暴力 NTT 复杂度一样,但是常数不要小太多!所以卡一卡实际上是可以过去的。
以及你会发现这玩意跑不满。NTT 规模实际上可以到 \(n - C_0\),这看起来没用,但你会发现这对 \(1\sim\sqrt n\) 都取满了的情况很友好,于是决定分析一波这玩意的最劣情况。设他有 \(x\) 个位置有值,此时要使得 \(C_0\) 最小,最优的是 \(2\sim x\) 各放一个,剩下放 \(1\),那么 \(C_0=\mathrm O\!\left(x^2\right)\),那么复杂度就是 \(\mathrm O\!\left((nx-x^3)\log n\right)\)。对 \(nx-x^3\) 求导,可知 \(x=\sqrt{\dfrac n3}\) 的时候最劣,算一下发现雀食优了不少。
code
constexpr int N = 1e6 + 10;
fc_init(N);
int n;
int deg[N];
struct poly : vi {
using vi::vi;
static constexpr int g = 3;
void NTT(int f) {
poly &a = *this;
int lim = 0;
while((1 << lim) < a.size()) ++lim;
if((1 << lim) > N) exit(114514);
a.resize(1 << lim, 0);
static int R[N];
REP(i, 0, a.size() - 1) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (lim - 1));
REP(i, 0, a.size() - 1) if(i < R[i]) ::swap(a[i], a[R[i]]);
for(int i = 1; i < a.size(); i <<= 1) {
int gn = qpow(g, (mod - 1) / (i << 1));
for(int j = 0; j < a.size(); j += i << 1){
int G = 1;
for(int k = 0; k < i; ++k, G = (ll)G * gn % mod) {
int x = a[j + k], y = (ll)G * a[j + k + i] % mod;
a[j + k] = add(x, y), a[j + k + i] = add(x, -y);
}
}
}
if(f == 1) return;
int nv = inv(a.size()); reverse(a.begin() + 1, a.end());
REP(i, 0, a.size() - 1) a[i] = (ll)a[i] * nv % mod;
}
friend poly operator*(poly x, poly y) {
int sz = x.size() + y.size() - 1;
x.resize(sz, 0), y.resize(sz, 0);
x.NTT(1), y.NTT(1);
REP(i, 0, x.size() - 1) x[i] = (ll)x[i] * y[i] % mod;
x.NTT(-1);
x.resize(sz);
return x;
}
void prt() {
for(int x : *this) cout << x << " "; puts("!");
}
};
int cnt[N];
void mian() {
n = read();
memset(deg, -1, sizeof(deg)); deg[1] = 0;
REP(i, 1, n - 1) { int x = read(), y = read(); ++deg[x], ++deg[y]; }
REP(i, 1, n) ++cnt[deg[i]];
int lim = 0;
while((1 << lim) < n - cnt[0] + 1) ++lim;
poly p(1 << lim, 1);
PER(i, n, 1) if(cnt[i]) {
int G = qpow(3, (mod - 1) >> lim), gn = 1;
REP(j, 0, (1 << lim) - 1) p[j] = (ll)p[j] * qpow(((ll)i * gn + 1) % mod, cnt[i]) % mod, gn = (ll)gn * G % mod;
}
p.NTT(-1);
int ans = 0;
REP(i, 0, min(n - 1, int(p.size()) - 1)) addto(ans, (i & 1 ? -1ll : 1ll) * fc[n - i] * p[i] % mod);
prt(ans), pc('\n');
}