题解[省选联考 2020 A 卷] 作业题
反演+矩阵树
首先题目要求的是 \(\sum\limits_{T}\sum\limits_{i=1}^{n-1}w_{e_{i}}\times gcd(w_{e_1},...,w_{e_{n-1}})\)
很明显的可以用反演,也可以直接套 \(\phi*1=id\)
那么就可以得出 \(\sum\limits_{d=1}^{max(w)}\phi(d)\sum\limits_{T}^{d|(gcd(w_{e\in T}))}\sum\limits_{i=1}^{n-1}w_{e_i}\)
后面部分需要用到矩阵树定理,但矩阵树求得是 \(\sum\limits_T\prod\limits_{i=1}^{n-1}w_{e_i}\),和上面的不一样,所以需要稍作转换
考虑将边权用一个一次函数 \(wx+1\) 表示,在模 \(x^2\) 下作乘法,那么你会发现一次项系数就是边权和
所以不妨定义一个多项式四则运算,加减直接对应加减就行了
乘法 \((ax+b)(cx+d)=(ad+bc)x+bd\)
除法,考虑 \((cx+d)\) 的逆元,即我们需要求 \((Ax+B)(cx+d)\equiv 1(mod~x^2)\)
\(Bd=1\Rightarrow B=\frac{1}{d}\)
\(Adx+Bcx=0\Rightarrow A=-\frac{c}{d^2}\)
那么就可以得到 \(\frac{ax+b}{cx+d}=(ax+b)(-\frac{c}{d^2}x+\frac{1}{d})=\frac{ad-bc}{d^2}x+\frac{b}{d}\)
直接高斯消元即可,时间复杂度 \(O(n^3max(w))\),有点勉强,这里有一个优化,即加边大于等于 \(n-1\) 才进行矩阵树,这样就是 \(O(144n^4)\)
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#define MOD (998244353)
using namespace std;
typedef long long ll;
ll qmi(ll a, ll b);
int n, m;
struct num
{
ll x, y;
num operator + (const num a) const
{return {(x + a.x) % MOD, (y + a.y) % MOD};}
num operator - (const num a) const
{return {(x - a.x + MOD) % MOD, (y - a.y + MOD) % MOD};}
num operator * (const num a) const
{return {(x * a.y + y * a.x) % MOD, y * a.y % MOD};}
num operator / (const num a) const
{
ll inv = qmi(a.y, MOD - 2);
return {((x * a.y - y * a.x) % MOD + MOD) * inv % MOD * inv % MOD, y * inv % MOD};
}
} a[35][35];
struct edge
{
int u, v, w;
} e[1005];
ll qmi(ll a, ll b)
{
ll res = 1;
while (b)
{
if (b & 1)
res = res * a % MOD;
a = a * a % MOD;
b >>= 1;
}
return res;
}
int gauss()
{
num res = num({0, 1});
int w = 1;
for (int i = 1; i < n; i++)
{
for (int j = i + 1; j < n; j++)
{
if (a[j][i].y)
{
swap(a[j], a[i]), w = -w;
break;
}
}
num inv = num({0, 1}) / a[i][i];
for (int j = i + 1; j < n; j++)
{
num d = a[j][i] * inv;
for (int k = i; k < n; k++)
a[j][k] = a[j][k] - a[i][k] * d;
}
}
for (int i = 1; i < n; i++)
res = res * a[i][i];
return w > 0 ? res.x : (num({0, 0}) - res).x;
}
int phi(int x)
{
int res = x;
for (int i = 2; i <= x; i++)
{
if (x % i == 0)
res -= res / i;
while (x % i == 0)
x /= i;
}
if (x > 1)
res -= res / x;
return res;
}
int main()
{
int mx = 0, ans = 0;
scanf("%d%d", &n, &m);
for (int i = 1; i <= m; i++)
{
scanf("%d%d%d", &e[i].u, &e[i].v, &e[i].w);
mx = max(mx, e[i].w);
}
for (int i = 1; i <= mx; i++)
{
memset(a, 0, sizeof(a));
int tot = 0;
for (int j = 1; j <= m; j++)
{
if (e[j].w % i)
continue;
tot++;
int u = e[j].u, v = e[j].v, w = e[j].w;
num P = num({w, 1});
a[u][u] = a[u][u] + P, a[v][v] = a[v][v] + P;
a[u][v] = a[u][v] - P, a[v][u] = a[v][u] - P;
}
if (tot < n - 1)
continue;
int t = gauss();
ans = (ans + (ll)phi(i)* t % MOD) % MOD;
}
printf("%d", ans);
return 0;
}