2025.4.16的模拟赛”三项式“题解
三项式 
题目描述
给定 \(n,m,l,r\),求满足以下条件的整数序列 \(A=[a_1,a_2,\cdots,a_n]\) 的个数,答案对 \(10^9+7\) 取模:
- 
\(\forall i,a_i\in [l,r]\) 
- 
\(S=\sum a,S\) 各个数位的和与各个数位的平方和对 \(m\) 同余。 
\(1\le n,m\le 20,0\le l\le r<10^{1000}\) 。
题解
直接对S数位dp会有组合数的贡献,难以同步处理。
于是考虑同时对 \(n\) 个数进行数位dp,先将每个数减去 \(nl\),最后在加上 \(nl\)。
设 \(dp_{i,m,c,p}\) 表示从高到低考虑第 \(len\sim i\) 位,第 \(i-1\to i\) 的预设进位(因为尚未更新 \(i-1\) 的情况,但可以通过“预设+检验”的方式实现更新)为 \(c\) 有 \(m\) 个数达到了 \(R\) 上界,\(S1-S2\bmod m=p\)(判断两个数相等时只需要存差值),每次转移枚举给个数的和,以及多少个数离开了上界,复杂度 \(\mathcal O(mn^5\lg r)\)。
现在要让所有数 \(\le R\),考虑容斥,钦定 \(x\) 个数 \(>R\),就不需要再存上界了,状态:\(dp_{i,c,p}\) 复杂度 \(\mathcal O(mn^4\lg r)\)。
令 \(v=t+c+bias\),其中 \(bias\) 表示偏移下界及容斥下界带来的额外贡献,写出转移式子:\(g_t\times f_{\lfloor\frac v{10}\rfloor,p}\to f'_{c,p-(v\bmod 10)^2+v\bmod 10}\)。其中 \(g_t\) 表示 \(n\) 个\([0,9]\) 的数总和为 \(t\)。发现当 \(t\) 在模 \(10\) 意义下,\(f\) 仅有第一维改变,于是可以预处理出来,令 \(s_{r,c',p}=\sum_{t\in\{t|t\bmod 10=r,\frac{t'+c+bias}{10}-\frac{t'+bias}{10}=c'\}}g_tf_{\frac{v}{10},t}\)。可以发现 \(0\le c\le 2\)
于是 $f'{p,p-(v\bmod 10)^2+v\bmod 10}\leftarrow s-\frac{t'+bias}{10},p} $,这样 \(t'\in[0,9]\) 减少了一个 \(n\)。
相当于一个和优化,复杂度 \(\mathcal O(mn^3\lg r)\)。
代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int CC=1e3+5,NN=25,MOD=1e9+7;
int n,m;
string l,r;
struct BigInt{
	int a[CC*NN],len;
	int&operator[](int x){return a[x];}
	BigInt(){memset(a,0,sizeof a);}
	BigInt(string s){
		len=s.size();
		for(int i=0;i<(int)s.size();i++)a[len-i-1]=s[i]-'0';
	}
	void operator-=(BigInt b){
		for(int i=0;i<len;i++){
			a[i]-=b[i];
			if(a[i]<0)a[i]+=10,a[i+1]--;
		}
		while(len&&a[len-1]==0)len--;
		return;
	}
	BigInt operator+(BigInt b){
		BigInt c;c.len=max(len,b.len);
		for(int i=0;i<max(len,b.len);i++){
			c[i]+=a[i]+b[i];
			if(c[i]>10)c[i+1]++,c[i]-=10;
		}
		while(c[c.len]){
			if(c[c.len]>10)c[c.len+1]++,c[c.len]-=10;
			c.len++;
		}
		return c;
	}
	void operator+=(BigInt b){
		*this=*this+b;
		return;
	}
	BigInt operator+(int v){
		BigInt c;c.len=len;
		for(int i=0;i<max(1ll,len);i++){
			c[i]+=a[i]+(i==0)*v;
			if(c[i]>10)c[i+1]++,c[i]-=10;
		}
		while(c[c.len]){
			if(c[c.len]>10)c[c.len+1]++,c[c.len]-=10;
			c.len++;
		}
		return c;
	}
	BigInt operator*(int v){
		BigInt c;c.len=len;
		for(int i=0;i<len;i++){
			c[i]+=a[i]*v;
			c[i+1]+=c[i]/10;
			c[i]%=10;
		}
		while(c[c.len]){
			c[c.len+1]+=c[c.len]/10;
			c[c.len]%=10;
			c.len++;
		}
		return c;
	}
	bool operator<(BigInt b){
		if(len!=b.len)return len<b.len;
		for(int i=len-1;i>=0;i--)if(a[i]!=b[i])return a[i]<b[i];
		return false;
	}
	void Print(){
		if(!len)cout<<0;
		for(int i=len-1;i>=0;i--)cout<<a[i];cout<<"\n";
	}
}L,R,nR,bia;
int ans,fac[10*NN],ifac[NN*10];
int QP(int a,int b){
	int c=1;
	for(;b;b>>=1){
		if(b&1)c=1ll*a*c%MOD;
		a=1ll*a*a%MOD;
	}
	return c;
}
void PreWork(int n){
	fac[0]=1;for(int i=1;i<=n;i++)fac[i]=1ll*fac[i-1]*i%MOD;
	ifac[n]=QP(fac[n],MOD-2);
	for(int i=n-1;i>=0;i--)ifac[i]=1ll*ifac[i+1]*(i+1)%MOD;
	return;
}
int C(int n,int m){
	return 1ll*fac[n]*ifac[m]%MOD*ifac[n-m]%MOD;
}
int Mod(int x,int mod){
	return (x%mod+mod)%mod;
}
int dp[CC][2][NN][NN],s[10][2][4][NN];//第3维表示i-1到i的进位 
int f[NN][NN*10];
void DP(){
	int len=nR.len;
	memset(dp,0,sizeof dp);
	dp[len][1][0][0]=1;//第len-1位是最高位,不应该再进位了 
	for(int i=len;i>=1;i--){//第i向第i-1位的转移 
		memset(s,0,sizeof s);
		for(int lim=0;lim<=1;lim++)
			for(int c=0;c<=3;c++)
				for(int p=0;p<m;p++)
					for(int t=0;t<=9*n;t++)//
						(s[t%10][lim][c][p]+=f[n][t]*dp[i][lim][(bia[i-1]+t)/10+c][p])%=MOD;
		for(int lim=0;lim<=1;lim++)//第i位时的上限状态 
			for(int c=0;c<=n+1;c++)//第i-2向第i-1位的进位 
				for(int p=0;p<m;p++)//第i~第len位的差 
					for(int t=0;t<=9;t++){//枚举第i-1位n个数的总和 
						int bit=(c+bia[i-1]+t)%10;
						if(lim&&bit>nR[i-1])continue;
						int _lim=lim&&bit==nR[i-1];
						int _p=Mod(p-bit*bit+bit,m);
						(dp[i-1][_lim][c][_p]+=s[t][lim][(c+t+bia[i-1])/10-(t+bia[i-1])/10][p])%=MOD;
					}
	}
	return;
}
int Sign(int x){return x&1?-1:1;}
signed main(){
	cin>>n>>m>>l>>r;
	PreWork(10*n);
	L=BigInt(l),R=BigInt(r);
	nR=R*n;
	R-=L;
	bia=L*n;
	f[0][0]=1;
	for(int i=1;i<=n;i++)
		for(int j=0;j<=9*i;j++)
			for(int k=0;k<=9;k++)
				if(j>=k)(f[i][j]+=f[i-1][j-k])%=MOD;
	for(int i=0;i<=n;i++){
		DP();
		(ans+=Sign(i)*C(n,i)*(dp[0][0][0][0]+dp[0][1][0][0]))%=MOD;
		bia+=R+1;
		if(nR<bia)break;
	}
	cout<<Mod(ans,MOD);
	return 0;
}

 
                
            
         浙公网安备 33010602011771号
浙公网安备 33010602011771号