OIFC 2026省选 0119

*歌 song

大胆容斥

序列可被看成若干段,每一段元素不重,且除首尾段外,其余长度 \(k\)。考虑枚举首段长度 \(l\),我们想钦定这是第一段长度最小的划分方式,于是考虑把划分点向前平移一定长度,则每一段都被分为两部分,每一段对应的部分值域相同,互为重排。不断这样做,发现每一段都被分为了若干部分,设有 \(c\) 个部分,则容斥系数 \((-1)^{c+1}\)

\(f_i\) 表示已有的部分长 \(i\),转移考虑当前部分长度 \(j\)。值得注意:跨过对应在首段中跨过 \(1\) 和对应在尾段中跨过 \(n\) 的部分会有额外的排列数倍贡献;第一部分长度应 \(\ge k-l\),否则平移到该位置首段长度反而增加。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<int,ll> pil;
typedef pair<ll,int> pli;
typedef pair<ll,ll> pll;
template<typename T>
void chkmin(T &x,const T &y){x=min(x,y);}
template<typename T>
void chkmax(T &x,const T &y){x=max(x,y);}
const int inf=0x3f3f3f3f;
const ll infll=0x3f3f3f3f3f3f3f3f;
int MOD;
void add(int &x,int y){
    x+=y;
    if(x>=MOD) x-=MOD;
}
int qpow(int a,ll b){
    int mul=1;
    while(b){
        if(b&1) mul=(ll)mul*a%MOD;
        a=(ll)a*a%MOD;
        b>>=1;
    }
    return mul;
}
const int K=755;
int fact[K],C[K][K];
int k,f[K],g[K];
ll n;
int A(int n,int m){
    return (ll)C[n][m]*fact[m]%MOD;
}
void __INIT__(){}
void __SOLVE__(){
    scanf("%lld%d%d",&n,&k,&MOD);
    fact[0]=1;
    for(int i=1;i<=k;i++) fact[i]=(ll)i*fact[i-1]%MOD;
    C[0][0]=1;
    for(int i=1;i<=k;i++){
        C[i][0]=C[i][i]=1;
        for(int j=1;j<i;j++) C[i][j]=(C[i-1][j-1]+C[i-1][j])%MOD;
    }
    int ans=0;
    for(int l=0;l<k;l++){
        // printf("l=%d\n",l);
        int r=(n-l)%k;
        ll m=(n-l)/k;
        for(int i=1;i<=k;i++) g[i]=qpow(fact[i],m)%MOD;
        // for(int i=0;i<=k;i++) printf("%d ",g[i]);printf("\n");
        f[0]=MOD-1;
        for(int i=1;i<k-l;i++) f[i]=0;
        for(int i=k-l;i<=k;i++){
            f[i]=0;
            for(int j=1;j<=i;j++){
                int tmp=g[j];
                if(k-i<l&&k-i+j>=l) tmp=(ll)tmp*A(j,l-(k-i))%MOD;
                if(k-i+j<l) tmp=(ll)tmp*fact[j]%MOD;
                if(i-j<r&&i>=r) tmp=(ll)tmp*A(j,r-(i-j))%MOD;
                if(i<r) tmp=(ll)tmp*fact[j]%MOD;
                // printf("%d ",tmp);
                add(f[i],MOD-(ll)tmp*f[i-j]%MOD*C[i][j]%MOD);
            }
            // printf("\n");
        }
        // for(int i=0;i<=k;i++) printf("%d ",f[i]);printf("\n");
        add(ans,f[k]);
    }
    printf("%d\n",ans);
    return;
}
int main(){
    #ifndef JZQ
    freopen("song.in","r",stdin);
    freopen("song.out","w",stdout);
    #endif
    int T=1;
    scanf("%d",&T);
    __INIT__();
    while(T--) __SOLVE__();
    return 0;
}

*火力大喵 dameow

分治

首先注意到答案为全局 \(\max\) - 全局 \(\min\),容斥,拿总数 - 无 \(\max\) 数 - 无 \(\min\) 数 + 无 \(\min,\max\) 数,这样变成对一个 \(01\) 矩阵,查询有多少个边框只覆盖 \(0\)

考虑像旅行者一样分治,假设当前分治区域长宽为 \(w,h\),想办法做到复杂度 \(\mathcal{O}(\min(w,h)^2+wh)\),每次切开长边,两侧对称,分别做一下,再用乘法原理合并。不妨只看左侧,设 \(h_i\) 表示分治中心第 \(i\) 个点左侧有多少个相邻的 \(0\)利用悬线法思想,对于上下边界在 \(x,y\) 的矩形,\(h\) 更小的地方做贡献,这样可以避免另一侧不够长,不妨设 \(h_x\le h_y\)。对当前区域里每个点,预处理其上下方有多少个 \(0\)。接下来枚举 \(x\),再枚举左边界,此时合法的 \(y\) 是一个区间(忽略 \(h\) 大小关系),打一个差分标记。所有标记打好后,扫一遍 \(y\),统计答案。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<int,ll> pil;
typedef pair<ll,int> pli;
typedef pair<ll,ll> pll;
template<typename T>
void chkmin(T &x,const T &y){x=min(x,y);}
template<typename T>
void chkmax(T &x,const T &y){x=max(x,y);}
const int inf=0x3f3f3f3f;
const ll infll=0x3f3f3f3f3f3f3f3f;
const int MOD=998244353;
void add(int &x,int y){
    x+=y;
    if(x>=MOD) x-=MOD;
}
int qpow(int a,ll b){
    int mul=1;
    while(b){
        if(b&1) mul=(ll)mul*a%MOD;
        a=(ll)a*a%MOD;
        b>>=1;
    }
    return mul;
}
const int N=2005;
int n,m,bd[N][N],a[N][N],f[N][N],g[N][N],h1[N],h2[N],dif[N],up[N][N],down[N][N];
ll calc(int l=1,int r=m,int u=1,int d=n){
    if(l==r&&u==d) return !a[l][u];
    // printf("[%d,%d] [%d,%d]\n",l,r,u,d);
    ll ans=0;
    if(r-l>d-u){
        int mid=(l+r)>>1;
        for(int j=l;j<=r;j++) up[u-1][j]=down[d+1][j]=0;
        for(int i=u;i<=d;i++){
            h1[i]=l-1;
            for(int j=l;j<=mid;j++){
                if(a[i][j]) h1[i]=j,up[i][j]=0;
                else up[i][j]=up[i-1][j]+1;
            }
            h2[i]=r+1;
            for(int j=r;j>mid;j--){
                if(a[i][j]) h2[i]=j,up[i][j]=0;
                else up[i][j]=up[i-1][j]+1;
            }
        }
        for(int i=d;i>=u;i--) for(int j=l;j<=r;j++){
            if(a[i][j]) down[i][j]=0;
            else down[i][j]=down[i+1][j]+1;
        }
        for(int x=u;x<=d;x++){
            for(int y=u;y<=d;y++) dif[y]=0;
            for(int j=mid;j>h1[x];j--) dif[x]++,dif[max(x-up[x][j],u-1)]--,dif[min(x+down[x][j],d+1)]--;
            for(int y=x-1;y>=u;y--){
                dif[y]+=dif[y+1];
                if(h1[y]<h1[x]) f[y][x]=dif[y];
            }
            for(int y=x+1;y<=d;y++){
                dif[y]+=dif[y-1];
                if(h1[y]<=h1[x]) f[x][y]=dif[y];
            }
            f[x][x]=mid-h1[x];
            for(int y=u;y<=d;y++) dif[y]=0;
            for(int j=mid+1;j<h2[x];j++) dif[x]++,dif[max(x-up[x][j],u-1)]--,dif[min(x+down[x][j],d+1)]--;
            for(int y=x-1;y>=u;y--){
                dif[y]+=dif[y+1];
                if(h2[y]>h2[x]) g[y][x]=dif[y];
            }
            for(int y=x+1;y<=d;y++){
                dif[y]+=dif[y-1];
                if(h2[y]>=h2[x]) g[x][y]=dif[y];
            }
            g[x][x]=h2[x]-mid-1;
        }
        // for(int x=u;x<=d;x++){
        //     for(int y=u;y<=d;y++) printf("%d ",f[x][y]);
        //     printf("\n");
        // }
        // for(int x=u;x<=d;x++){
        //     for(int y=u;y<=d;y++) printf("%d ",g[x][y]);
        //     printf("\n");
        // }
        for(int x=u;x<=d;x++) for(int y=x;y<=d;y++) ans+=(ll)f[x][y]*g[x][y];
        // printf("ans %lld\n",ans);
        return ans+calc(l,mid,u,d)+calc(mid+1,r,u,d);
    }
    else{
        int mid=(u+d)>>1;
        for(int j=u;j<=d;j++) up[j][l-1]=down[j][r+1]=0;
        for(int j=l;j<=r;j++){
            h1[j]=u-1;
            for(int i=u;i<=mid;i++){
                if(a[i][j]) h1[j]=i,up[i][j]=0;
                else up[i][j]=up[i][j-1]+1;
            }
            h2[j]=d+1;
            for(int i=d;i>mid;i--){
                if(a[i][j]) h2[j]=i,up[i][j]=0;
                else up[i][j]=up[i][j-1]+1;
            }
        }
        for(int j=r;j>=l;j--) for(int i=u;i<=d;i++){
            if(a[i][j]) down[i][j]=0;
            else down[i][j]=down[i][j+1]+1;
        }
        for(int x=l;x<=r;x++){
            // printf("x=%d\n",x);
            for(int y=l;y<=r;y++) dif[y]=0;
            for(int i=mid;i>h1[x];i--) dif[x]++,dif[max(x-up[i][x],l-1)]--,dif[min(x+down[i][x],r+1)]--;
            for(int y=x-1;y>=l;y--){
                dif[y]+=dif[y+1];
                if(h1[y]<h1[x]) f[y][x]=dif[y];
            }
            for(int y=x+1;y<=r;y++){
                dif[y]+=dif[y-1];
                if(h1[y]<=h1[x]) f[x][y]=dif[y];
            }
            f[x][x]=mid-h1[x];
            for(int y=l;y<=r;y++) dif[y]=0;
            for(int i=mid+1;i<h2[x];i++) dif[x]++,dif[max(x-up[i][x],l-1)]--,dif[min(x+down[i][x],r+1)]--;
            for(int y=x-1;y>=l;y--){
                dif[y]+=dif[y+1];
                if(h2[y]>h2[x]) g[y][x]=dif[y];
            }
            for(int y=x+1;y<=r;y++){
                dif[y]+=dif[y-1];
                if(h2[y]>=h2[x]) g[x][y]=dif[y];
            }
            g[x][x]=h2[x]-mid-1;
        }
        // for(int x=l;x<=r;x++){
        //     for(int y=l;y<=r;y++) printf("%d ",f[x][y]);
        //     printf("\n");
        // }
        // for(int x=l;x<=r;x++){
        //     for(int y=l;y<=r;y++) printf("%d ",g[x][y]);
        //     printf("\n");
        // }
        for(int x=l;x<=r;x++) for(int y=x;y<=r;y++) ans+=(ll)f[x][y]*g[x][y];
        // printf("ans %lld\n",ans);
        return ans+calc(l,r,u,mid)+calc(l,r,mid+1,d);
    }
    return 114514;
}
void __INIT__(){}
void __SOLVE__(){
    scanf("%d%d",&n,&m);
    int mx=-inf,mn=inf;
    for(int i=1;i<=n;i++){
        for(int j=1;j<=m;j++){
            scanf("%d",&bd[i][j]);
            chkmin(mn,bd[i][j]),chkmax(mx,bd[i][j]);
        }
    }
    ll ans=(ll)n*(n+1)/2*m*(m+1)/2;
    if(mn==mx){
        printf("0 %lld\n",ans);
        return;
    }
    for(int i=1;i<=n;i++) for(int j=1;j<=m;j++) a[i][j]=(bd[i][j]==mn);
    // printf("---\n");
    // for(int i=1;i<=n;i++){
    //     for(int j=1;j<=m;j++) printf("%d ",a[i][j]);
    //     printf("\n");
    // }
    ans-=calc();
    for(int i=1;i<=n;i++) for(int j=1;j<=m;j++) a[i][j]=(bd[i][j]==mx);
    // printf("---\n");
    // for(int i=1;i<=n;i++){
    //     for(int j=1;j<=m;j++) printf("%d ",a[i][j]);
    //     printf("\n");
    // }
    ans-=calc();
    for(int i=1;i<=n;i++) for(int j=1;j<=m;j++) a[i][j]=(bd[i][j]==mn||bd[i][j]==mx);
    // printf("---\n");
    // for(int i=1;i<=n;i++){
    //     for(int j=1;j<=m;j++) printf("%d ",a[i][j]);
    //     printf("\n");
    // }
    ans+=calc();
    printf("%d %lld\n",mx-mn,ans);
}
int main(){
    #ifndef JZQ
    freopen("dameow.in","r",stdin);
    freopen("dameow.out","w",stdout);
    #endif
    int T=1;
    // scanf("%d",&T);
    __INIT__();
    while(T--) __SOLVE__();
    return 0;
}

另一种做法。

把四条边的限制用无限制、一条边限制、两条边限制、三条边限制表达,对各种情况分别计算。

*松鼠威廉梦游仙境 squirrel

不要把四种情况求和看成取 \(\max\)

对于固定的 \(x,y\),根据插板法可知买法数

\[\binom{x+n-1}{n-1}\binom{y+m-1}{m-1} \]

事先把 \(n,m\)\(1\)

根据 Lucas 定理的结论:

\[\binom{n}{m}=\prod_i\binom{n_i}{m_i} \]

其中 \(x_i\) 表示 \(x\)\(p\) 进制表示下第 \(i\) 位。

在模 \(p\) 意义下,只有最低位会产生贡献

考虑从低往高数位 DP,维护 \(ax+by\) 的进位。设当前位进位 \(j\),下一位进位 \(j'\),则 \(k_i+pj'=ax_i+by_i+j\le(a+b)(p-1)+j\),归纳可得 \(j\le a+b\),所以状态数 \(\mathcal{O}((a+b)\log_p{k})\),转移复杂度 \(\mathcal{O}(p)\)

考虑转移时组合数的贡献,看起来 \(x+n\) 的进位会有问题,但实际上,如果有了进位,那么 \(x_i+n_i-p<n_i\),组合数 \(=0\),所以永远不会加出来进位。

总复杂度 \(\mathcal{O}((a+b)\log k\frac{p}{\log p})\)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<int,ll> pil;
typedef pair<ll,int> pli;
typedef pair<ll,ll> pll;
typedef __int128_t int128;
template<typename T>
void chkmin(T &x,const T &y){x=min(x,y);}
template<typename T>
void chkmax(T &x,const T &y){x=max(x,y);}
const int inf=0x3f3f3f3f;
const ll infll=0x3f3f3f3f3f3f3f3f;
int MOD;
struct FastMod{
    unsigned long long b, m;
    FastMod() = default;
    FastMod(unsigned long long m) : b(((unsigned __int128)1 << 64) / m), m(m) {}
    unsigned long long operator()(unsigned long long a) {
        unsigned long long q = (unsigned __int128)a * b >> 64;
        unsigned long long r = a - q * m;
        return r >= m ? r - m : r;
    }
} fastmod;
void add(int &x,int y){
    x+=y;
    if(x>=MOD) x-=MOD;
}
int qpow(int a,ll b){
    int mul=1;
    while(b){
        if(b&1) mul=(ll)mul*a%MOD;
        a=(ll)a*a%MOD;
        b>>=1;
    }
    return mul;
}
const int N=65,W=105,P=1000005;
ll n,m,k;
int a,b,f[N][W],nbit[N],mbit[N],kbit[N];
int c0,c1,c2,c3;
int fact[P],invfact[P],inv[P];
int Cx[P],Cy[P];
int C(int n,int m){
    if(n<m) return 0;
    return fastmod(fastmod((ll)fact[n]*invfact[m])*invfact[n-m]);
}
void __INIT__(){}
void __CLEAR__(){
    fastmod=FastMod(MOD);
    fact[0]=1;
    for(int i=1;i<MOD;i++) fact[i]=fastmod((ll)i*fact[i-1]);
    invfact[MOD-1]=MOD-1;
    for(int i=MOD-1;i>0;i--) invfact[i-1]=fastmod((ll)i*invfact[i]);
    for(int i=1;i<MOD;i++) inv[i]=fastmod((ll)invfact[i]*fact[i-1]);
}
void __SOLVE__(){
    scanf("%lld%lld%lld%d%d%d",&n,&m,&k,&a,&b,&MOD);
    scanf("%d%d%d%d",&c0,&c1,&c2,&c3);
    __CLEAR__();
    n--,m--;
    int L=0;
    for(int128 mul=1;mul<k;mul*=MOD,L++);
    for(int i=0;i<=L+1;i++) nbit[i]=mbit[i]=kbit[i]=0;
    for(int i=0;n;n/=MOD,i++) nbit[i]=n%MOD;
    for(int i=0;m;m/=MOD,i++) mbit[i]=m%MOD;
    for(int i=0;k;k/=MOD,i++) kbit[i]=k%MOD;
    int inva=qpow(a,MOD-2),invb=qpow(b,MOD-2);
    int tmp=MOD-fastmod((ll)a*invb);
    for(int i=0;i<=L+1;i++) for(int j=0;j<=a+b;j++) f[i][j]=0;
    for(int x=0,y=fastmod((ll)kbit[0]*invb);x+nbit[0]<MOD;x++){
        if(x) add(y,tmp);
        // printf("x=%d y=%d\n",x,y);
        if(y+mbit[0]<MOD)
            add(f[0][(a*x+b*y)/MOD],fastmod(fastmod((ll)C(x+nbit[0],nbit[0])*C(y+mbit[0],mbit[0]))*(c0+fastmod((ll)x*c1)+fastmod((ll)y*c2)+fastmod(fastmod((ll)x*y)*c3))));
    }
    for(int i=0;i<=L;i++){
        Cx[0]=Cy[0]=1;
        for(int x=1;x+nbit[i+1]<MOD;x++) Cx[x]=fastmod(fastmod((ll)Cx[x-1]*(x+nbit[i+1]))*inv[x]);
        for(int y=1;y+mbit[i+1]<MOD;y++) Cy[y]=fastmod(fastmod((ll)Cy[y-1]*(y+mbit[i+1]))*inv[y]);
        for(int j=0;j<=(a+b);j++){
            if(!f[i][j]) continue;
            int y=kbit[i+1]-j;
            if(y<0) y+=MOD;
            y=fastmod((ll)y*invb);
            for(int x=0;x+nbit[i+1]<MOD;x++){
                if(x) add(y,tmp);
                if(y+mbit[i+1]<MOD)
                    add(f[i+1][(a*x+b*y+j)/MOD],fastmod(fastmod((ll)Cx[x]*Cy[y])*f[i][j]));
            }
        }
    }
    printf("%d\n",f[L+1][0]);
}
int main(){
    #ifndef JZQ
    freopen("squirrel.in","r",stdin);
    freopen("squirrel.out","w",stdout);
    #endif
    int T=1;
    scanf("%d",&T);
    __INIT__();
    while(T--) __SOLVE__();
    return 0;
}
posted @ 2026-01-20 22:18  SmpaelFx  阅读(0)  评论(0)    收藏  举报