题解[省选联考 2020 A 卷] 作业题

反演+矩阵树

首先题目要求的是 \(\sum\limits_{T}\sum\limits_{i=1}^{n-1}w_{e_{i}}\times gcd(w_{e_1},...,w_{e_{n-1}})\)

很明显的可以用反演,也可以直接套 \(\phi*1=id\)

那么就可以得出 \(\sum\limits_{d=1}^{max(w)}\phi(d)\sum\limits_{T}^{d|(gcd(w_{e\in T}))}\sum\limits_{i=1}^{n-1}w_{e_i}\)

后面部分需要用到矩阵树定理,但矩阵树求得是 \(\sum\limits_T\prod\limits_{i=1}^{n-1}w_{e_i}\),和上面的不一样,所以需要稍作转换

考虑将边权用一个一次函数 \(wx+1\) 表示,在模 \(x^2\) 下作乘法,那么你会发现一次项系数就是边权和

所以不妨定义一个多项式四则运算,加减直接对应加减就行了

乘法 \((ax+b)(cx+d)=(ad+bc)x+bd\)

除法,考虑 \((cx+d)\) 的逆元,即我们需要求 \((Ax+B)(cx+d)\equiv 1(mod~x^2)\)

\(Bd=1\Rightarrow B=\frac{1}{d}\)

\(Adx+Bcx=0\Rightarrow A=-\frac{c}{d^2}\)

那么就可以得到 \(\frac{ax+b}{cx+d}=(ax+b)(-\frac{c}{d^2}x+\frac{1}{d})=\frac{ad-bc}{d^2}x+\frac{b}{d}\)

直接高斯消元即可,时间复杂度 \(O(n^3max(w))\),有点勉强,这里有一个优化,即加边大于等于 \(n-1\) 才进行矩阵树,这样就是 \(O(144n^4)\)

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#define MOD (998244353)
using namespace std;
typedef long long ll;

ll qmi(ll a, ll b);
int n, m;
struct num
{
    ll x, y;
    num operator + (const num a) const
    {return {(x + a.x) % MOD, (y + a.y) % MOD};}
    num operator - (const num a) const
    {return {(x - a.x + MOD) % MOD, (y - a.y + MOD) % MOD};}
    num operator * (const num a) const
    {return {(x * a.y + y * a.x) % MOD, y * a.y % MOD};}
    num operator / (const num a) const
    {
        ll inv = qmi(a.y, MOD - 2);
        return {((x * a.y - y * a.x) % MOD + MOD) * inv % MOD * inv % MOD, y * inv % MOD};
    }
} a[35][35];
struct edge
{
    int u, v, w;
} e[1005];

ll qmi(ll a, ll b)
{
    ll res = 1;
    while (b)
    {
        if (b & 1)
            res = res * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return res;
}

int gauss()
{
    num res = num({0, 1});
    int w = 1;
    for (int i = 1; i < n; i++)
    {
        for (int j = i + 1; j < n; j++)
        {
            if (a[j][i].y)
            {
                swap(a[j], a[i]), w = -w;
                break;
            }
        }
        num inv = num({0, 1}) / a[i][i];
        for (int j = i + 1; j < n; j++)
        {
            num d = a[j][i] * inv;
            for (int k = i; k < n; k++)
                a[j][k] = a[j][k] - a[i][k] * d;
        }
    }
    for (int i = 1; i < n; i++)
        res = res * a[i][i];
    return w > 0 ? res.x : (num({0, 0}) - res).x;
}

int phi(int x)
{
    int res = x;
    for (int i = 2; i <= x; i++)
    {
        if (x % i == 0)
            res -= res / i;
        while (x % i == 0)
            x /= i;
    }
    if (x > 1)
        res -= res / x;
    return res;
}

int main()
{
    int mx = 0, ans = 0;
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= m; i++)
    {
        scanf("%d%d%d", &e[i].u, &e[i].v, &e[i].w);
        mx = max(mx, e[i].w);
    }
    for (int i = 1; i <= mx; i++)
    {
        memset(a, 0, sizeof(a));
        int tot = 0;
        for (int j = 1; j <= m; j++)
        {
            if (e[j].w % i)
                continue;
            tot++;
            int u = e[j].u, v = e[j].v, w = e[j].w;
            num P = num({w, 1});
            a[u][u] = a[u][u] + P, a[v][v] = a[v][v] + P;
            a[u][v] = a[u][v] - P, a[v][u] = a[v][u] - P;
        }
        if (tot < n - 1)
            continue;
        int t = gauss();
        ans = (ans + (ll)phi(i)* t % MOD) % MOD;
    }
    printf("%d", ans);
    return 0;
}
posted @ 2021-05-10 20:03  DSHUAIB  阅读(59)  评论(0编辑  收藏  举报