题解[国家集训队]Crash的数字表格 / JZPTAB
可以知道,题目中所求的是
\[\sum\limits_{i=1}^n\sum\limits_{j=1}^m\operatorname{lcm}(i,j)=\sum\limits_{i=1}^n\sum\limits_{j=1}^m\frac{ij}{\gcd(i,j)}
\]
枚举 \(d=\gcd(i,j)\) 的所有可能取值:
\[\begin{aligned}\sum\limits_{i=1}^n\sum\limits_{j=1}^m\frac{ij}{\gcd(i,j)}&=\sum\limits_{d=1}^{\min(n,m)}\sum\limits_{i=1}^n\sum\limits_{j=1}^m\frac{ij}{d}[\gcd(i,j)=d]\\&=\sum\limits_{d=1}^{\min(n,m)}d\sum\limits_{i=1}^{\lfloor\frac{n}{d}\rfloor}\sum\limits_{j=1}^{\lfloor\frac{m}{d}\rfloor}ij[\gcd(i,j)=1]\end{aligned}
\]
令 \(sum(n,m)=\sum\limits_{i=1}^n\sum\limits_{j=1}^mij[\gcd(i,j)=1]\)。对固定的 \(n,m\),再令 \(f(k)=\sum\limits_{i=1}^n\sum\limits_{j=1}^mij[\gcd(i,j)=k]\),\(F(k)=\sum\limits_{i=1}^n\sum\limits_{j=1}^mij[k\mid\gcd(i,j)]\)
\[\begin{aligned}&\Rightarrow F(k)=\sum\limits_{k\mid d}f(d)\Rightarrow f(k)=\sum\limits_{k\mid d}\mu\left(\frac{d}{k}\right)F(d)\\&=\sum\limits_{k\mid d}\mu\left(\frac{d}{k}\right)\sum\limits_{i=1}^n\sum\limits_{j=1}^mij[d\mid\gcd(i,j)]=\sum\limits_{k\mid d}\mu\left(\frac{d}{k}\right)d^2\sum\limits_{i=1}^{\lfloor\frac{n}{d}\rfloor}\sum\limits_{j=1}^{\lfloor\frac{m}{d}\rfloor}ij\end{aligned}
\]
令 \(g(i,j)=\frac{i\times(i+1)}{2}\times\frac{j\times(j+1)}{2}\)
\[\Rightarrow f(k)=\sum\limits_{k\mid d}\mu\left(\frac{d}{k}\right)d^2g(\lfloor\frac{n}{d}\rfloor,\lfloor\frac{m}{d}\rfloor)
\]
\[\Rightarrow sum(n,m)=f(1)=\sum\limits_{d=1}^{\min(n,m)}\mu(d)d^2g(\lfloor\frac{n}{d}\rfloor,\lfloor\frac{m}{d}\rfloor)
\]
\[ans=\sum\limits_{d=1}^{\min(n,m)}d\cdot sum(\lfloor\frac{n}{d}\rfloor,\lfloor\frac{m}{d}\rfloor)
\]
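推导到这里,就能明显看出内外两层求和都可以用整除分块计算:用线性筛预处理 \(\mu(d)d^2\) 的前缀和需要 \(O(n)\),两层整除分块合计 \(O(n^{3/4})\),因此总时间复杂度为 \(O(n)\)。注意取模和 long long;C++ 中负数取模的结果为负,做减法或乘以 \(\mu(d)=-1\) 之后要加上模数再取模,保证结果落在 \([0,MOD)\) 内。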
代码
#include<iostream>
#include<cstdio>
// Modulus required by the problem statement.
const int MOD = 20101009;
using namespace std;
const int N = 1e7 + 5;
typedef long long ll;
// g(x, y) = (1 + 2 + ... + x) * (1 + 2 + ... + y) mod MOD,
// i.e. sum_{i<=x} sum_{j<=y} i*j, reduced modulo MOD.
// A function rather than a macro: each argument is evaluated exactly once and
// expressions such as g(a << 1, b) keep their intended precedence.
// All intermediates stay far below 2^63 for 0 <= x, y < N.
inline ll g(ll x, ll y)
{
    ll tx = x * (x + 1) / 2 % MOD;
    ll ty = y * (y + 1) / 2 % MOD;
    return tx * ty % MOD;
}
// Tables filled by init():
//   st[i]    - true once i has been crossed out as a composite,
//   prime[]  - primes found so far (1-based),
//   mu[i]    - Moebius function,
//   s[i]     - prefix sums of mu(d) * d^2 modulo MOD.
bool st[N];
int prime[N], mu[N];
ll s[N];
void init()
{
int tot = 0;
mu[1] = 1;
for (int i = 2; i < N; i++)
{
if (!st[i])
prime[++tot] = i, mu[i] = -1;
for (int j = 1; i * prime[j] < N; j++)
{
st[i * prime[j]] = true;
if (i % prime[j] == 0)
break;
mu[i * prime[j]] = -mu[i];
}
}
for (int i = 1; i < N; i++)
s[i] = ((ll)s[i - 1] + (ll)mu[i] * i * i + MOD) % MOD;
}
// Returns sum_{i<=n} sum_{j<=m} i*j*[gcd(i,j) = 1] modulo MOD, in [0, MOD),
// via sum(n, m) = sum_{d=1}^{min(n,m)} mu(d) * d^2 * g(n/d, m/d).
// Values of d sharing the same (n/d, m/d) are handled as one block using the
// prefix sums s[] built by init().
ll sum(int n, int m)
{
    int k = min(n, m);
    ll res = 0;
    for (int l = 1, r; l <= k; l = r + 1)
    {
        // Largest r with n/r == n/l and m/r == m/l (and r <= k).
        r = min(k, min(n / (n / l), m / (m / l)));
        // Bring the block's prefix-sum difference into [0, MOD) even if the
        // stored prefix sums are negative (C++ '%' keeps the dividend's sign).
        ll block = ((s[r] - s[l - 1]) % MOD + MOD) % MOD;
        res = (res + block * g(n / l, m / l)) % MOD;
    }
    return res;
}
int main()
{
init();
int n, m;
scanf("%d%d", &n, &m);
int k = min(n, m);
ll res = 0;
for (int l = 1, r; l <= k; l = r + 1)
{
r = min(k, min(n / (n / l), m / (m / l)));
res += (ll)(l + r) * (r - l + 1) / 2 * sum(n / l, m / l);
res %= MOD;
}
printf("%lld", res);
return 0;
}