字越少事越大?
题目描述
海豚想问有多少个n位k进制数模m为0且至少某一位为x的数的个数。
输入格式
一行四个数分别表示n,k,m,x
输出格式
输出答案对998244353取模的结果
好明显的数位dp!构建数组\(f_{ijp}\) ,\(i\)表示现在是第几位(以最低位为第\(0\)位算),\(j\)表示当前模\(m\)的余数,\(p\)为当前是否已经出现过\(x\)。显然有\(\mathbf{{\large f_{i-1,(j+t\cdot k^{i-1})\bmod{m},p\vee t==x}+=f_{i,j,p}}}\),其中\(t\)在\([0,k)\)范围内枚举。
显然这只能拿一半的分。
那么在位数范围如此之大的情况下我们如何加速dp呢?
诶嘿,矩阵快速幂!
考虑构建转移矩阵,由于\(p\)无论如何都只可能为\(0\)或\(1\),所以我们可以给它展开,构建一个\(2m*2m\)的转移矩阵,同时我们不能再枚举了,所以要提前算出各个余数有多少个数可以得到。我们就愉快的解决了……吗?
还有一个很重要的问题啊,我们的转移方程是要参考\(k^i\)的,它的值会随数位的变化而变化,转移矩阵都不同,怎么快速幂?
这时我们注意到:\(\mathbf{note:k}\)与\(\mathbf{m}\)互质。
根据欧拉定定理
\(若\gcd(a, p) =1,则a^{\varphi (p)}\equiv1\pmod{m}\)
所以我们可以计算出在\([0,\varphi(m))\)位上的所有转移矩阵,并将它们乘起来作为底数进行快速幂,剩余部分暴力乘起来就好了。由于\(m\)很小,所以不用担心会超时。
最后注意初始状态以及首位判\(0\)。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const long long MOD=998244353;
long long n,q,x,m,phi,c[50];
struct matrix{
long long r,c,m[110][110];
}tra[60];
matrix f,b;
matrix operator *(matrix a,matrix b)
{
matrix s;
s.r=a.r,s.c=b.c;
for(long long i=0;i<s.r;i++)
for(long long j=0;j<s.c;j++)
{
s.m[i][j]=0;
for(long long k=0;k<a.c;k++)
s.m[i][j]=(s.m[i][j]+a.m[i][k]*b.m[k][j]%MOD)%MOD;
}
return s;
}
matrix fm(matrix bot,long long pow)
{
matrix s;
s.c=s.r=bot.c;
for(long long i=0;i<s.r;i++)
for(long long j=0;j<s.c;j++)
s.m[i][j]=(i==j);
for(;pow;pow>>=1,bot=bot*bot)
if(pow&1)
s=s*bot;
return s;
}
long long fm(long long bot,long long pow,long long p)
{
long long s=1;
for(;pow;pow>>=1,bot=bot*bot%p)
if(pow&1)
s=s*bot%p;
return s;
}
void getphi(long long u,long long &a)
{
a=1;
for(long long i=2;i*i<=u;i++)
{
if(u%i)
continue;
a*=i-1;
u/=i;
while(u%i==0)
a*=i,u/=i;
}
if(u>1)
a*=u-1;
}
int main()
{
scanf("%lld %lld %lld %lld",&n,&q,&m,&x);
for(long long i=0;i<m&&i<q;i++)
c[i]=(q-1-i)/m+1;
getphi(m,phi);
b.c=b.r=2*m;
for(long long i=0;i<2*m;i++)
b.m[i][i]=1;
for(long long i=0,t=1;i<phi;i++,t=t*q%m)
{
tra[i].c=tra[i].r=2*m;
for(long long j=0;j<m;j++)
{
for(long long k=0;k<m;k++)
tra[i].m[j][(j+t*k%m)%m]+=c[k]-(x%m==k);
tra[i].m[j][(j+x%m*t%m)%m+m]=1;
}
for(long long j=0;j<m;j++)
for(long long k=0;k<m;k++)
tra[i].m[j+m][(j+t*k%m)%m+m]+=c[k];
b=b*tra[i];
}
f.r=1,f.c=2*m;
if(x)
{
for(long long i=0;i<m;i++)
f.m[0][i*fm(q,n-1,m)%m]+=c[i]-(!i)-(x%m==i);
f.m[0][x%m*fm(q,n-1,m)%m+m]=1;
}
else
for(long long i=0;i<m;i++)
f.m[0][i*fm(q,n-1,m)%m]+=c[i]-(!i);
long long t=n-1;
while(t%phi)
f=f*tra[(t-1)%phi],t--;
f=f*fm(b,t/phi);
printf("%lld\n",f.m[0][m]);
return 0;
}