P5900 无标号无根树计数 题解

不懂为啥都要对原式神秘转化之后再牛顿迭代,直接对原式牛顿迭代即可!完全不用转化!

设无标号有根树的组合类是 \(\mathcal T\),则有 \(\mathcal T=\mathcal Z\times\mathrm{MSET}(\mathcal T)\),即 \(T(x)=x\exp\sum\limits_{j\ge1}\dfrac{T(x^j)}j\)

\(G(F(x))=F(x)-x\exp\sum\limits_{j\ge1}\dfrac{F(x^j)}j=0\),要求 \(F(x)\bmod x^n\),问题变为牛顿迭代的形式。

首先钦定 \(F(x)\bmod x^1=[x^0]F(x)=0\)(否则 \(\exp\sum\limits_{j\ge1}\dfrac{F(x^j)}j\) 的常数项不收敛),

然后假设已经求出模 \(x^{\frac n2}\) 意义下的解 \(F_0(x)\),则模 \(x^n\) 意义下的解 \(F(x)\equiv F_0(x)-\dfrac{G(F_0(x))}{G'(F_0(x))}\pmod{x^n}\)

考虑如何求 \(G'(F_0(x))\)

观察到 \(G(F_0(x))=F_0(x)-x\exp F_0(x)\exp\sum\limits_{j\ge2}\dfrac{F_0(x^j)}j\)

设最终答案是 \(H(x)\)(这里 \(H(x)\) 是与 \(F_0(x)\) 无关的常量),

\(\forall j\ge2\),有 \(F_0(x^j)\equiv H(x^j)\pmod{x^n}\),则 \(\exp\sum\limits_{j\ge2}\dfrac{F_0(x^j)}j\equiv\exp\sum\limits_{j\ge2}\dfrac{H(x^j)}j\pmod{x^n}\)

于是 \(\exp\sum\limits_{j\ge2}\dfrac{F_0(x^j)}j\) 是与 \(F_0(x)\) 无关的常量,

\(G'(F_0(x))=1-x\exp F_0(x)\exp\sum\limits_{j\ge2}\dfrac{F_0(x^j)}j=1-x\exp\sum\limits_{j\ge1}\dfrac{F_0(x^j)}j\)

于是 \(F(x)\equiv F_0(x)-\dfrac{G(F_0(x))}{G'(F_0(x))}\equiv F_0(x)-\dfrac{F_0(x)-x\exp\sum\limits_{j\ge1}\dfrac{F_0(x^j)}j}{1-x\exp\sum\limits_{j\ge1}\dfrac{F_0(x^j)}j}\pmod{x^n}\)

现在我们会算无标号有根树了,设 \(f_i\) 表示 \(i\) 个点的无标号有根树个数,考虑怎么算无标号无根树。

\(n\) 为奇数时,每个无标号无根树都有唯一的重心,所以只需统计根为重心的无标号有根树个数,

考虑用 \(f_n\) 减去根不为重心的无标号有根树个数,这些树一定有一棵大小超过 \(\left\lfloor\dfrac n2\right\rfloor\) 的子树,

枚举这棵子树的大小,则答案为 \(f_n-\sum\limits_{i=\left\lfloor\frac n2\right\rfloor+1}^{n-1}f_i\times f_{n-i}\)

\(n\) 为偶数时,考虑用同样的计数方法,

发现有两个重心的树,即有一棵大小为 \(\dfrac n2\) 的子树的树,可能会被重复统计,

设这棵大小为 \(\dfrac n2\) 的子树为 \(T\),若原树去掉 \(T\) 后剩下的树与 \(T\) 完全一致,

则原树分别以两个重心为根形成的有根树完全一致,这样的树不会被重复统计,

反之若原树去掉 \(T\) 后剩下的树与 \(T\) 不同,这样的树就会被重复统计,

所以从大小为 \(\dfrac n2\) 的有根树中任选出两棵,就可以组合出一种被重复统计的树,

所以有 \(f_{\frac n2}\choose2\) 种被重复统计的树,答案减去 \(f_{\frac n2}\choose2\) 即可。

没有封装多项式类,所以代码比较混乱邪恶,建议谨慎阅读。

#include <cstdio>
#include <cstring>
#include <algorithm>
#define int long long
#define G 3
#define _G 332748118
#define M 998244353
using namespace std;
int n, l, r[6000050], f[6000050], g[6000050], h[6000050], x[6000050], y[6000050], z[6000050], k[6000050], v[6000050];
int P(int x, int y)
{
    int q = 1;
    for (; y; y >>= 1, x = x * x % M)
        if (y & 1)
            q = q * x % M;
    return q;
}
void F(int *f, int n, int v)
{
    for (int i = 0; i < n; ++i)
        if (i < r[i])
            swap(f[i], f[r[i]]);
    for (int L = 2, m; L <= n; L <<= 1)
    {
        m = L >> 1;
        int W = P(v == 1 ? G : _G, (M - 1) / L);
        for (int l = 0, r = L - 1; r <= n; l += L, r += L)
        {
            int o = 1;
            for (int p = l; p < l + m; ++p)
            {
                int x = f[p], y = f[p + m];
                f[p] = (x + o * y) % M, f[p + m] = (x + M - o * y % M) % M;
                o = o * W % M;
            }
        }
    }
}
void I(int *f, int *g, int n)
{
    memset(g, 0, n << 4);
    memset(x, 0, n << 4);
    g[0] = P(f[0], M - 2);
    int L;
    for (L = 4;; L <<= 1)
    {
        memcpy(x, f, L << 2);
        memcpy(y, g, L << 3);
        l = __lg(L);
        for (int i = 0; i < L; ++i)
            r[i] = r[i >> 1] >> 1 | (i & 1) << l - 1;
        F(x, L, 1);
        F(y, L, 1);
        for (int i = 0; i < L; ++i)
            x[i] = x[i] * y[i] % M;
        F(x, L, -1);
        int _ = P(L, M - 2);
        for (int i = 0; i < L; ++i)
            x[i] = (M - x[i] * _ % M) % M;
        x[0] = (x[0] + 2) % M;
        memset(x + (L >> 1), 0, L << 2);
        F(g, L, 1);
        F(x, L, 1);
        for (int i = 0; i < L; ++i)
            g[i] = g[i] * x[i] % M;
        F(g, L, -1);
        for (int i = 0; i < L; ++i)
            g[i] = g[i] * _ % M;
        if (L >> 1 >= n)
            break;
    }
    memset(g + n, 0, L - n << 3);
}
void LN(int *f, int *g, int n)
{
    memset(h, 0, n << 4);
    for (int i = 0; i < n - 1; ++i)
        h[i] = (i + 1) * f[i + 1] % M;
    I(f, g, n);
    int L = 1;
    while (L >> 1 < n)
        L <<= 1;
    l = __lg(L);
    for (int i = 0; i < L; ++i)
        r[i] = r[i >> 1] >> 1 | (i & 1) << l - 1;
    F(g, L, 1);
    F(h, L, 1);
    for (int i = 0; i < L; ++i)
        h[i] = g[i] * h[i] % M;
    F(h, L, -1);
    int _ = P(L, M - 2);
    for (int i = 0; i < L; ++i)
        h[i] = h[i] * _ % M;
    g[0] = 0;
    for (int i = 1; i < n; ++i)
        g[i] = h[i - 1] * P(i, M - 2) % M;
    memset(g + n, 0, L - n << 3);
}
void EXP(int *f, int *g, int n)
{
    memset(g, 0, n << 4);
    g[0] = 1;
    int L;
    for (L = 4;; L <<= 1)
    {
        LN(g, z, L >> 1);
        for (int i = 0; i < L >> 1; ++i)
            z[i] = (f[i] + M - z[i]) % M;
        z[0] = (z[0] + 1) % M;
        l = __lg(L);
        for (int i = 0; i < L; ++i)
            r[i] = r[i >> 1] >> 1 | (i & 1) << l - 1;
        F(g, L, 1);
        F(z, L, 1);
        for (int i = 0; i < L; ++i)
            g[i] = g[i] * z[i] % M;
        F(g, L, -1);
        int _ = P(L, M - 2);
        for (int i = 0; i < L; ++i)
            g[i] = g[i] * _ % M;
        memset(g + (L >> 1), 0, L << 2);
        if (L >> 1 >= n)
            break;
    }
    memset(g + n, 0, L - n << 3);
}
signed main()
{
    v[1] = 1;
    for (int i = 2; i <= 6e6; ++i)
        v[i] = (M - M / i) * v[M % i] % M;
    scanf("%lld", &n), ++n;
    for (int L = 4;; L <<= 1)
    {
        memset(g, 0, L << 2);
        for (int i = 1; i < L >> 2; ++i)
            for (int j = i; j < L >> 1; j += i)
                g[j] = (g[j] + f[i] * i % M * v[j]) % M;
        EXP(g, k, L >> 1);
        for (int i = (L >> 1) - 1; i >= 1; --i)
            k[i] = (M - k[i - 1]) % M;
        k[0] = 1;
        I(k, g, L >> 1);
        k[0] = 0;
        for (int i = 1; i < L >> 2; ++i)
            k[i] = (k[i] + f[i]) % M;
        F(g, L, 1);
        F(k, L, 1);
        for (int i = 0; i < L; ++i)
            g[i] = g[i] * k[i] % M;
        F(g, L, -1);
        int _ = P(L, M - 2);
        for (int i = 0; i < L; ++i)
            g[i] = g[i] * _ % M;
        for (int i = 0; i < L >> 1; ++i)
            f[i] = (f[i] + M - g[i]) % M;
        if (L >> 1 >= n)
            break;
    }
    --n;
    int q = f[n];
    for (int i = (n >> 1) + 1; i < n; ++i)
        q = (q + M - f[i] * f[n - i] % M) % M;
    if (!(n & 1))
    {
        int u = f[n >> 1];
        q = (q + M - u * (u - 1) % M * (M + 1 >> 1) % M) % M;
    }
    printf("%lld", q);
    return 0;
}
posted @ 2024-04-24 15:04  5k_sync_closer  阅读(22)  评论(1编辑  收藏  举报