Min-25筛学习笔记

\(Min-25\)筛可以快速求解形如\(\sum\limits_{i=1}^nf(i)\)的式子。但要求\(f\)满足:

  • 是积性函数
  • \(f(p)(p\in{P})\)是一个低阶多项式
  • \(f(p^k)\)能快速求出

我们设\(mp(i)\)表示\(i\)的最小质因子,\(p_i\)表示第\(i\)个质数。

再令\(g(n,j)=\sum\limits_{i=1}^n[i\in{P}\lor{mp(i)>p_j}]f'(i)\)\(f'(i)\)表示\(f(i)\)这个低次多项式的某一项),考虑如何从\(g(n,j-1)\)转移:要减去所有\(mp(i)=p_j\)的合数的贡献。

所以有:

\[g(n,j)=g(n,j-1)-f'(p_j)*(g(\frac{n}{p_j},j-1)-g(p_{j-1},j-1)) \]

(把\(f'\)提出来是因为积性函数的性质,后面的是要把多减掉的质数加回来)

注意到\(g(p_{j-1},j-1)\)就是前\(j-1\)个质数的\(f'(i)\)之和,设为\(sp(j-1)\)

\(\leq{n}\)\(p\)\(x\)个,那么\(g(n,x)\)就是所有质数的\(f'(p)\)之和。

\(s(n,k)=\sum\limits_{i=1}^n[mp(i)>p_k]f(i)\),答案就是\(s(n,0)\)

因为\(s\)可以由\(i>p_k\)的质数或者\(mp(i)>p_k\)的合数产生贡献,所以可以把\(s\)\(g\)联系在一起:

\[s(n,k)=g(n,x)-sp_k+\sum\limits_{p_i^j\leq{n}\land{i>k}}f(p_i^j)*(s(\frac{n}{p_i^j},i)+[j>1]) \]

(前面是质数,后面是合数。合数部分,首先枚举\(>p_k\)的最小质因子\(p_i\)\(+[j>1]\)是因为只有一个质因子的数(\(p_i^j\))也要算上,但\(j=1\)时就是质数,不能算重。注意此处是\(f\),不是\(f'\)。)

递归求解\(s(n,0)\)即可。

要注意以下几点:

  • \(g\)可以使用滚动数组优化
  • 因为求\(g\)时后面的部分是\(\frac{n}{p_j}\),又\(\left\lfloor\frac{\lfloor\frac na\rfloor}{b}\right\rfloor=\lfloor\frac{n}{ab}\rfloor\),所以我们没必要算出来所有的\(n\),只需要算出可以写成\(\lfloor\frac{n}{x}\rfloor\)这种形式的数,这样的数一共有\(O(\sqrt n)\)个。可以用两个数组\(id1\)\(id2\)分别存\(x\leq\sqrt{n}\)\(x\)\(x>\sqrt{n}\)\(\frac{n}{x}\),下标就只会到\(\sqrt{n}\)
  • \(1\)既不是质数也不是合数,不含任何一个质因子,所以\(g\)\(s\)都没有包含\(1\),只需要最后加\(1\)就行了。
  • \(f(p)\)是一个低次多项式,为了保证积性函数的性质,需要把它的每一项拆开,分别来计算。如\(f(p^k)=p^k(p^k-1)\)时,\(f(p)=p(p-1)\),那么我们分开计算\(f'(i)=i^2\)\(f'(i)=i\)的。

时间复杂度是\(O(\frac{n^{3/4}}{\log n})\)\(O(n^{1-\epsilon})\)(不会证明)

下面是\(f(p^k)=p^k(p^k-1)\)的代码:

Code

#include <bits/stdc++.h>

using namespace std;

#define ll unsigned long long

const ll mod = 1e9 + 7, inv6 = (mod + 1) / 6ll;

int lim, tot, nw, vis[200005], pr[80005];

ll n, s1[200005], s2[200005], p[200005], g1[200005], g2[200005], id1[200005], id2[200005];

ll read()
{
	ll x = 0, fl = 1; char ch = getchar();
	while (ch < '0' || ch > '9') { if (ch == '-') fl = -1; ch = getchar();}
	while (ch >= '0' && ch <= '9') {x = x * 10ll + ch - '0'; ch = getchar();}
	return x * fl;
}

void init()
{
	for (int i = 2; i <= lim; i ++ )
	{
		if (!vis[i])
		{
			vis[i] = i;
			pr[ ++ tot] = i;
		}
		for (int j = 1; j <= tot; j ++ )
		{
			if (i * pr[j] > lim || pr[j] > vis[i]) break;
			vis[i * pr[j]] = pr[j];
		}
	}
	for (int i = 1; i <= tot; i ++ )
	{
		s1[i] = (s1[i - 1] + (ll)pr[i]) % mod;
		s2[i] = (s2[i - 1] + 1ll * pr[i] * pr[i] % mod) % mod;
	}
	return;
}

ll c1(ll x)
{
	x %= mod;
	return x * (x + 1) / 2ll % mod;
}

ll c2(ll x)
{
	x %= mod;
	return x * (x + 1) % mod * (2ll * x % mod + 1ll) % mod * inv6 % mod;
}

int gt(ll x)
{
	if (x <= lim) return id1[x];
	return id2[n / x];
}

ll Min_25(ll x, int k)
{
	if (pr[k] > x) return 0;
	ll cnt = ((g2[gt(x)] - g1[gt(x)] + mod) % mod - (s2[k] - s1[k] + mod) % mod + mod) % mod;
	for (int i = k + 1; i <= tot && 1ll * pr[i] * pr[i] <= x; i ++ )
		for (ll o = 1, p0 = (ll)pr[i]; p0 <= x; p0 *= (ll)pr[i], o ++ )
			cnt = (cnt + 1ll * (p0 % mod - 1) % mod * p0 % mod * (Min_25(x / p0, i) + (o != 1)) % mod + mod) % mod;
	return cnt;
}

int main()
{
	n = read(), lim = (int)((ll)sqrt(n));
	init();
	for (ll l = 1, r; l <= n; l = r + 1)
	{
		r = n / (n / l), p[ ++ nw] = n / l;
		g1[nw] = c1(n / l) - 1, g2[nw] = c2(n / l) - 1;
		if (n / l <= lim) id1[n / l] = nw;
		else id2[r] = nw;
	}
	for (int i = 1; i <= tot; i ++ )
	{
		for (int j = 1; j <= nw && 1ll * pr[i] * pr[i] <= p[j]; j ++ )
		{
			g1[j] = (g1[j] - 1ll * pr[i] * (g1[gt(p[j] / (ll)pr[i])] - s1[i - 1] + mod) % mod + mod) % mod;
			g2[j] = (g2[j] - 1ll * pr[i] * pr[i] % mod * (g2[gt(p[j] / (ll)pr[i])] - s2[i - 1] + mod) % mod + mod) % mod;
		}
	}
	printf("%llu\n", Min_25(n, 0) + 1);
	return 0;
}
posted @ 2021-03-25 13:01  andysj  阅读(55)  评论(0编辑  收藏  举报