BZOJ4805 - 欧拉函数求和

Portal

Description

给出\(n(n\leq2\times10^9)\),求\(\sum_{i=1}^n \varphi(i)\)

Solution

杜教筛。
杜教筛的作用就是以一个低于\(O(n)\)(准确来说是\(O(n^{\frac{2}{3}})\))的时间复杂度来计算积性函数\(f\)的前缀和。
\(S(x)=\sum_{i=1}^x f(i)\)。那么对于任意一个积性函数\(g\),我们有

\[\begin{align*} \sum_{i=1}^n (f\times g)(i) &= \sum_{d=1}^n \sum_{d|i} g(d)f(\frac{n}{d}) \\ &= \sum_{d=1}^n g(d) \sum_{i=1}^{\lfloor \frac{n}{d} \rfloor} f(i) \\ &= \sum_{d=1}^n g(d)S(\lfloor \frac{n}{d} \rfloor) \\ &= g(1)S(n) + \sum_{d=2}^n g(d)S(\lfloor \frac{n}{d} \rfloor) \\ S(n) &= \sum_{i=1}^n (f\times g)(i) - \sum_{d=2}^n g(d)S(\lfloor \frac{n}{d} \rfloor) \end{align*} \]

我们可以用整除分块来做后面的部分,那么如果我们能让前面的部分好算,就解决了。对于本题,因为有\(\varphi \times 1=id\),所以令\(g(x)=1(x)\),就有\(\sum_{i=1}^n (\varphi \times 1)(i)=\sum_{i=1}^n i=\dfrac{n(n+1)}{2}\)

根据我并不会算的复杂度,我们可以预处理出$x\leq n^{\frac{2}{3}}$内的所有$S(x)$,此时计算$S(n)$的复杂度也为$O(n^{\frac{2}{3}})$。

我们发现我们会计算到\(S(⌊\dfrac{n}{2}⌋),S(⌊\dfrac{⌊\frac{n}{2}⌋}{2}⌋=⌊\dfrac{n}{4}⌋),S(⌊\dfrac{n}{3}⌋)...\)这么一类数,即除了预处理的\(S(1..n^{\frac{2}{3}})\),还要计算\(S(⌊\dfrac{n}{1..n^{\frac{1}{3}}}⌋)\)\(n^{\frac{1}{3}}\)个数。记录\(S_1(x)=S(⌊\dfrac{n}{x}⌋)\),就可以进行记忆化搜索啦。

Code

//欧拉函数求和
#include <cstdio>
typedef long long lint;
const int N1=2e6;
int n0,n1;
int prCnt,pr[N1]; bool prNot[N1];
int phi[N1]; lint S[N1],S1[N1];
void init(int n)
{
    phi[1]=1;
    for(int i=2;i<=n;i++)
    {
        if(!prNot[i])
        {
            pr[++prCnt]=i; phi[i]=i-1;
            for(lint j=1LL*i*i;j<=n;j*=i) phi[j]=phi[j/i]*(i-1);
        }
        for(int j=1;j<=prCnt;j++)
        {
            int x=i*pr[j]; if(x>n) break;
            prNot[x]=true;
            if(i%pr[j]) phi[x]=phi[i]*(pr[j]-1);
            else {phi[x]=phi[i]*pr[j]; break;}
        }
    }
    for(int i=1;i<=n;i++) S[i]=S[i-1]+phi[i];
}
lint sum(int n)
{
    if(n<=n1) return S[n];
    if(S1[n0/n]) return S1[n0/n];
    lint r=n*(n+1LL)/2;
    for(int L=2,R;L<=n;L=R+1)
    {
        int v=n/L; R=n/v;
        r-=(R-L+1)*sum(v);
    }
    return S1[n0/n]=r;
}
int main()
{
    scanf("%d",&n0);
    for(n1=1;n1*n1*n1<=n0;n1++); n1*=n1;
    init(n1);
    printf("%lld\n",sum(n0));
    return 0;
}
posted @ 2018-05-08 10:03  VisJiao  阅读(340)  评论(0编辑  收藏  举报