记一道有趣的校模拟赛题

题面

分析

显然的,每个排列可以通过 \(\text{Cantor Expansion}\) 映射到一个唯一的整数,这是一个双射,满足题目要求。

\(\text{Cantor Expansion}\) 公式:

\[\begin{cases} A_i=\sum\limits_{j=i}^n[a_j<a_i] \\ f(a)=\sum\limits_{i=1}^n A_i\times(n-i)!,\text{for a sequence a} \end{cases} \]

那么现在来考虑加入 \(0\) 的情况。

显然初始时我们算出的序列值(即不含 \(0\) 的贡献)乘上 \(m!\) 即为初值,现在我们需要考虑增量。(注意,序列编号从 \(1\) 开始)

设有 \(m\) 个位置上都是 \(0\),即 \(m\) 个位置上的数都被抹去了。我们发现,对于这些 \(0\) 给出一些赋值,会影响它前面(不包含 \(0\))的 \(A_i\),它前面的 \(0\)\(A_i\) (假设已经有一些值),及它自身位置上的 \(A_i\)

就这样直接算的话,需要 \(O(nm! \log n)\),考虑优化。

不妨把每一个 \(0\) 都拆出来算。先预处理出 \(m\)\(0\) 能填的数的集合,从小到大排序,记作 \(c\)。再预处理出第 \(i\)\(0\) 所在的位置,记作 \(b_i\)。记当前第 \(j\)\(0\) 上填的数为 \(c_k\)。当我们考虑第 \(j\)\(0\) 对他前面的 \(0\)\(A_i\) 的贡献时,只需要枚举那些第 \(1\sim j-1\) 位上大于 \(c_k\) 的情况。于是有贡献:

\[\sum_{x=1}^{j-1}(n-b_x)!\times(num-2)!\times(m-k) \]

现在考虑算它前面的 \(A_i\),可以直接用求逆序对的方式算。对于前面的 \(A_i\),增量是 \((n-i)!\times (m-1)!\),很好理解。

至于它自身的 \(A_i\),这个东西倒着扫一遍,完全就是求逆序对,对于第 \(j\)\(0\),贡献式是:

\[\sum_{x=1}^m (n-i)!\times(m-1)!\times(\sum_{y=b_j}^n [c_x>a_y]) \]

这些东西都可以用 \(O(nm^2 \log n)\) 的时间复杂度求出。提取无关项出来,预处理一些前缀和可以做到 \(O(nm \log n)\)。继续观察,将每一位 \(0\) 枚举值的部分直接统计(即不再枚举值)就可以做到 \(O(n \log n)\)。最优美的写法可以做到除了初始的 \(\text{Cantor Expansion}\) 外不再使用 \(\text{BIT}\)

代码

#include<cstdio>
#include<algorithm>
#define int ll
typedef long long ll;
const int mod=1e9+7;
int n,s[300005],pref[300005],fac[300005],a[300005],vis[300005],b[300005],c[300005];
inline int read() {
    register int x=0,f=1;register char s=getchar();
    while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
    while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
    return x*f;
}
inline int min(const int &x,const int &y) {return x<y? x:y;}
inline void add(int x,int d) {for(;x<=n;x+=x&(-x)) s[x]=(s[x]+d)%mod;}
inline int ask(int x) {int res=0; for(;x;x-=x&(-x)) res=(res+s[x])%mod; return res;}
signed main() {
    n=read();int ans=0,num=0,tot=0,val=0;
    fac[0]=1;
    for(register int i=1;i<=n;++i) a[i]=read(),fac[i]=fac[i-1]*i%mod;
    for(register int i=1;i<=n;++i) {
        if(a[i]==0) b[++num]=i;
        vis[a[i]]=1;
    } val=num*(num-1)/2%mod;
    for(register int i=1;i<=n;++i) if(!vis[i]) c[++tot]=i;
    for(register int i=1;i<=num;++i) pref[i]=(pref[i-1]+fac[n-b[i]])%mod;
    // for(register int i=1;i<=num;++i) tmp1[i]=(tmp1[i-1]+num-i)%mod;
    // for(register int i=1;i<=num;++i) pref[i]=(pref[i-1]+fac[num-2]*fac[n-b[i]]%mod)%mod;
    for(register int i=n;i>=1;--i) {
        if(a[i]) {
            ans=(ans+fac[n-i]*ask(a[i]))%mod;//less minus
            // bas=(bas+fac[n-i]*ask(a[i]-1))%mod;
            add(a[i],1);
        }
    }
    ans=(ans+1)*fac[num]%mod;
    // printf("%d\n",ans);
    // for(register int i=1;i<=n;++i) s[i]=0;
    int sum1=0,sum2=0;
    for(register int i=1;i<=n;++i) {
        if(!a[i]) {
            int pos=std::lower_bound(b+1,b+1+num,(ll)i)-b,sum=0;
            // for(register int j=1;j<=num;++j) {
                // sum=(sum+(ask(n)-ask(c[j]))*fac[num-1]%mod)%mod;//考虑直接维护 $\sum_{i=1}^{num} ask(c_i)$
                // sum=(sum+fac[num-2]*(num-j)%mod*pref[pos-1]%mod)%mod;//提取无关项出来
                // tmp=(tmp+min(pos-1,num-j))%mod;
                // for(register int k=1;k<pos;++k) sum=(sum+fac[num-2]*fac[n-b[k]]%mod*(num-j)%mod)%mod;
            // }
            // sum=(sum+(ask(n)*num%mod-ask(c[num]))%mod*fac[num-1]%mod)%mod;
            sum=(sum+(sum2*num-sum1)%mod*fac[num-1])%mod;
            sum=(sum+fac[num-2]*pref[pos-1]%mod*val)%mod;
            // tmp=tmp*pref[pos-1]%mod;
            ans=(ans+sum)%mod;
        }
        else {
            int pos=std::upper_bound(c+1,c+1+num,(ll)a[i])-c;
            // add(a[i],fac[n-i]*(num-pos+1));
            // add(a[i],fac[n-i]);
            if(a[i]<c[num]) sum1=(sum1+fac[n-i]*(num-pos+1))%mod;
            sum2=(sum2+fac[n-i])%mod;
        }
    }
    // for(register int i=1;i<=n;++i) s[i]=0;
    sum1=0;
    for(register int i=n;i>=1;--i) {
        if(!a[i]) {
            // int sum=0;
            // for(register int j=1;j<=num;++j) sum=(sum+fac[n-i]*ask(c[j]-1)%mod*fac[num-1]%mod)%mod;
            ans=(ans+sum1*fac[n-i]%mod*fac[num-1])%mod;
        }
        else {
            // add(a[i],1);
            int pos=std::upper_bound(c+1,c+1+num,(ll)a[i])-c;
            if(a[i]<c[num]) sum1=(sum1+(num-pos+1))%mod;
        }
    }
    printf("%lld\n",(ans%mod+mod)%mod);
    return 0;
}
posted @ 2020-10-30 08:25  tommymio  阅读(113)  评论(0编辑  收藏  举报