P3746 [六省联考 2017] 组合数问题 题解
这个题 \(k\) 很小,我们考虑从 \(k\) 入手。
我们记
现在我们来介绍几个求和变形的技巧:
一、增加枚举量
把 \(f(n,r)\) 展开 \(k\) 次,得:
二、交换枚举顺序
明显交换两个 \(\sum\) 不影响答案,所以:
三、分离无关变量
容易发现此时 \(\binom{k}{j}\) 和内层的 \(\sum\) 没有关系了,提出来:
然后我们把式子整理一下:
我们会发现,第二个 \(\sum\) 正好是 \(f(n-1,r-j)\)!所以:
我们使用一个经典套路:给他扩充一维,凑成矩阵乘法:
这个时候我们把 \(f(n)\) 看成一个 \(1\times (k+1)\) 的矩阵,那么我们实际上需要凑出一个矩阵 \(D\),使其满足:
\(D\) 是一个 \((k+1)\times (k+1)\) 的矩阵。
对比原式,只需要令 \(D[r-j][j]=\binom{k}{j}\) 即可。
这样就能用矩阵乘法了。
有 \(f_n=f_0\times D^n\)。\(D\) 已经构造出来了,现在的问题就是构造 \(f_0\)。
我们有
很明显,只有当 \(r=0\) 的时候式子是 \(1\),其余为 \(0\)。
所以,\(f_0\) 是一个 \(1\times(k+1)\) 的矩阵,且只有 \(f_0[0][0]=1\),其余都是 \(0\)。
但还有一个问题:\(r-j\) 可能 \(<0\)。
我们考虑证明:\(j>0\) 时,\(f(i,0,-j)=f(i,0,k-j)\)。
把 \(f(i,0,k-j)\) 展开,得:
实际上,因为 \(\binom{ik}{0k-j}=\binom{ik}{-j}=0\),所以原式
这样,如果 \(r-j<0\),我们就可以直接给 \(D[r-j+k][j]\) 加上 \(\binom{k}{j}\) 了(不能直接赋值,因为一个位置可能会被重复算多次)。此时,\(D\) 的两维就都是 \([0,k]\) 之间的数了。
然后我们就可以直接进行矩阵乘法了(由于预处理时已经解决了负数的情况,所以直接从 \(0\) 到 \(k\) 枚举下标即可)。
这样我们就做完了。时间复杂度 \(O(k^3\log n)\)。
代码:
#include<bits/stdc++.h>
//#pragma GCC optimize("Ofast")
#define gt getchar
#define pt putchar
#define y1 y233
//typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
//typedef __int128 lll;
//typedef __uint128_t ulll;
const int N=55;
using namespace std;
inline bool __(char ch){return ch>=48&&ch<=57;}
inline int read(){
int x=0;bool sgn=0;char ch=gt();
while(!__(ch)&&ch!=EOF){sgn|=(ch=='-');ch=gt();}
while(__(ch)){x=(x<<1)+(x<<3)+(ch-48);ch=gt();}
return sgn?-x:x;
}
template<class T>
inline void print(T x){
static char st[70];short top=0;
if(x<0)pt('-');
do{st[++top]=x>=0?(x%10+48):(-(x%10)+48),x/=10;}while(x);
while(top)pt(st[top--]);
}
template<class T>
inline void printsp(T x){
static char st[70];short top=0;
if(x<0)pt('-');
do{st[++top]=x>=0?(x%10+48):(-(x%10)+48),x/=10;}while(x);
while(top)pt(st[top--]);pt(32);
}
template<class T>
inline void println(T x){
static char st[70];short top=0;
if(x<0)pt('-');
do{st[++top]=x>=0?(x%10+48):(-(x%10)+48),x/=10;}while(x);
while(top)pt(st[top--]);pt(10);
}
inline void put_str(string s){
int siz=s.size();
for(int i=0;i<siz;++i) pt(s[i]);
printf("\n");
}
int n,p,k,r,C[N][N];
struct Matrix{
int n,m;
ll val[N][N];
Matrix(int _n=0,int _m=0){
n=_n,m=_m;
memset(val,0,sizeof(val));
}
}f0,D;
inline Matrix operator*(const Matrix &a,const Matrix &b){
Matrix c(a.n,b.m);
for(int i=0;i<c.n;++i){
for(int k=0;k<a.m;++k){
for(int j=0;j<c.m;++j){
c.val[i][j]=(c.val[i][j]+a.val[i][k]*b.val[k][j]%p)%p;
}
}
}
return c;
}
inline Matrix ksm(Matrix a,int b){
Matrix c=a;
b--;
while(b){
if(b&1)c=c*a;
a=a*a,b>>=1;
}
return c;
}
signed main(){
n=read(),p=read(),k=read(),r=read();
C[0][0]=1;
for(int i=1;i<=k;++i){
C[i][0]=1;
for(int j=1;j<=k;++j) C[i][j]=(C[i-1][j]+C[i-1][j-1])%p;
}
f0.n=1,f0.m=k+1;
f0.val[0][0]=1;
for(int j=1;j<=k;++j) f0.val[0][j]=0;
D.n=D.m=k+1;
for(int r=0;r<=k;++r){
for(int j=0;j<=k;++j){
if(r>=j) D.val[r][r-j]+=C[k][j];
else D.val[r][r-j+k]+=C[k][j];
}
}
Matrix F=f0*ksm(D,n);
println(F.val[0][r]);
return 0;
}