peiwenjun's blog 没有知识的荒原

P9135 [THUPC 2023 初赛] 快速 LCM 变换 题解

题目描述

给定一个长为 \(n\) 的序列 \(r\) ,对 \(\forall 1\le i\lt j\le n\) ,删除 \(r_i,r_j\) ,再加入 \(r_i+r_j\)

求得到的 \(\frac{n(n-1)}2\) 个数列的最小公倍数之和,对 \(998244353\) 取模。

数据范围

  • \(2\le n\le 5\cdot 10^5,1\le r_i\le 10^6\)

时间限制 \(\texttt{2s}\) ,空间限制 \(\texttt{512MB}\)

分析

\(\forall p\in\text{prime}\) ,记 \(a(p),b(p)\) 为原数组中含 \(p\) 的幂次最大、次大值。

那么原序列的 \(\text{lcm}\)\(M=\prod\limits_{p\in\text{prime}}p^{a(p)}\) ,考虑以 \(M\) 为基准研究操作对 \(\text{lcm}\) 的影响。

\(\texttt{Key observation}\) :删除 \(r_i,r_j\) ,再加入 \(r_i+r_j\)\(\text{lcm}\)\(p\) 幂次的增量为:

\[\max(v_p(r_i+r_j)-a,0)-[v_p(r_i)=a](a-b)-[v_p(r_j)=a](a-b) \]

证明:不妨 \(v_p(r_i)\ge v_p(r_j)\)

  • 如果 \(v_p(r_i+r_j)\ge a\) ,那么 \(v_p(r_i)=v_p(r_j)\) ,因此后面的减法不会产生贡献。
  • 如果 \(v_p(r_i)=a,v_p(r_i+r_j)<a\) ,那么 \(v_p(r_j)\lt b\) ,增量为 \(b-a\)
  • 如果 \(v_p(r_i)\lt a\) ,显然增量为零。

本题最巧妙的地方在于,上述结论仅对两个数的情形成立

对于 \(\ge 3\) 个数,增量不能直接拆成关于 \(r_1,\cdots,r_k,\sum r_k\) 完全独立的变量。

\(f(i)=\prod\limits_{p\in\text{prime}}(\frac 1p)^{[v_p(r_i)=a](a-b)},g(i)=\prod\limits_{p\in\text{prime}}p^{\max(v_p(i)-a,0)}\) ,则目标变为:

\[\sum_{1\le i<j\le n}f(r_i)f(r_j)g(r_i+r_j)\\ =\frac12\bigg(\sum_{i=1}^n\sum_{j=1}^nf(r_i)f(r_j)g(r_i+r_j)-\sum_{i=1}^nf(r_i)^2g(2\cdot r_i)\bigg) \]

前者只需对值域做卷积,将 \(f(r_i)f(r_j)\) 贡献到 \(r_i+r_j\) 的位置即可统计答案。

时间复杂度 \(\mathcal O(n\sqrt V+V\log V)\)

#include<bits/stdc++.h>
using namespace std;
const int v=2e6,maxn=1<<21,mod=998244353,inv2=(mod+1)/2,inv3=(mod+1)/3;
int n,x,res;
int a[maxn],b[maxn],r[maxn];
int f[maxn],g[maxn],h[maxn],pw[25],cnt[maxn];
bool vis[maxn];
inline int qpow(int a,int k)
{
    int res=1;
    while(k)
    {
        if(k&1) res=1ll*res*a%mod;
        a=1ll*a*a%mod,k>>=1;
    }
    return res;
}
inline int add(int x,int y)
{
    if((x+=y)>=mod) x-=mod;
    return x;
}
inline int dec(int x,int y)
{
    if((x-=y)<0) x+=mod;
    return x;
}
inline void ntt(int *a,int n,int op)
{
    for(int i=0;i<n;i++) if(i<r[i]) swap(a[i],a[r[i]]);
    for(int k=2,m=1;k<=n;k<<=1,m<<=1)
    {
        int x=qpow(op==1?3:inv3,(mod-1)/k);
        for(int i=0;i<n;i+=k)
            for(int j=i,w=1;j<i+m;j++)
            {
                int v=1ll*a[j+m]*w%mod;
                a[j+m]=dec(a[j],v),a[j]=add(a[j],v);
                w=1ll*w*x%mod;
            }
    }
    if(op==-1)
    {
        int inv=qpow(n,mod-2);
        for(int i=0;i<n;i++) a[i]=1ll*a[i]*inv%mod;
    }
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&r[i]);
        int x=r[i];
        for(int j=2;j*j<=x;j++)
        {
            if(x%j) continue;
            int cnt=0;
            while(x%j==0) x/=j,cnt++;
            if(a[j]<=cnt) b[j]=a[j],a[j]=cnt;
            else b[j]=max(b[j],cnt);
        }
        if(x!=1)
        {
            if(a[x]<=1) b[x]=a[x],a[x]=1;
            else b[x]=max(b[x],1);
        }
    }
    for(int i=1;i<=v;i++) f[i]=g[i]=1;
    for(int p=2;p<=v;p++)
    {
        if(vis[p]) continue;
        int inv=qpow(p,mod-1-(a[p]-b[p]));
        for(long long i=0,x=1;x<=v;i++,x*=p) pw[i]=x;
        for(int i=p;i<=v;i+=p)
        {
            vis[i]=1,cnt[i]=cnt[i/p]+1;
            if(cnt[i]==a[p]) f[i]=1ll*f[i]*inv%mod;
            if(cnt[i]>a[p]) g[i]=1ll*g[i]*pw[cnt[i]-a[p]]%mod;
        }
        for(int i=p;i<=v;i+=p) cnt[i]=0;
    }
    for(int i=1;i<=n;i++)
    {
        int x=r[i];
        res=(res-1ll*f[x]*f[x]%mod*g[2*x])%mod;
        h[x]=add(h[x],f[x]);
    }
    n=1<<21;
    for(int i=0;i<n;i++) r[i]=(r[i>>1]>>1)|(i&1?n>>1:0);
    ntt(h,n,1);
    for(int i=0;i<n;i++) h[i]=1ll*h[i]*h[i]%mod;
    ntt(h,n,-1);
    for(int i=0;i<n;i++) res=(res+1ll*h[i]*g[i])%mod;
    res=1ll*(res+mod)*inv2%mod;
    for(int p=2;p<=v;p++) if(a[p]) res=1ll*res*qpow(p,a[p])%mod;
    printf("%d\n",res);
    return 0;
}

posted on 2023-04-30 21:09  peiwenjun  阅读(15)  评论(0)    收藏  举报

导航