关于数位DP
数位DP如其名字,对于一个整数,将其每个数位作为阶段进行DP,题目要求往往是在某个区间内满足某个条件的数有多少个,区间长度也往往很大,一般超过10的八次方,也就没办法一个个枚举然后check。当然,也有其他形式的,但其都有一个共同点,数需要满足的条件可以通过数位之间的关系来表达和判断。并且无后效性。
考虑为什么以数位作为阶段能优化复杂度,这是因为对于一个数,可能在我知道它的前几位或者后几位的信息后就不合法了,然而我依旧花了时间check它,而且如果我能在一个较高的数位就确定其不合法,那么一个个枚举必然浪费极多时间。
而以数位作为阶段的话,一是因为可能在某一位我们就能确定其后数位无论如何都不合法,就能省去很多无效遍历,还有就是对于两个数,其中一个数可能已经遍历过了,并且根据这个数可以直接推出另一个数的合法情况。这样不就省去了甚多情况。
例题:windy数
题目意思很明显,考虑从高到低的数位上进行dp,采用递归的形式,显然每次需要传递的参数为位置,前一个数,是否到了上界,以及有没有前导零。
然后分情况来写就行。递归边界是数位为1的时候,问题很容易解决。
#include<bits/stdc++.h>
using namespace std;
#define int long long
int dp[10][10][2][2];
vector<int> ans;
int dfs(int pos,int pre,bool reach,bool head){
int res = 0;
if(dp[pos][pre][reach][head])return dp[pos][pre][reach][head];
if(pos==1){
if(!reach){
if(head){
return 10;
}else {
for(int i = 0;i <= 9;i ++){
if(abs(i - pre) < 2)continue;
res ++;
}
}
}else {
if(head){
return ans[0];
}else {
for(int i = 0;i <= ans[0];i ++){
if(abs(i - pre) < 2)continue;
res ++;
}
}
}
return res;
}
if(!reach){
if(head){
for(int i = 0;i <= 9;i ++){
bool flg1 = 0;
if(i==0)flg1 = 1;
res += dfs(pos - 1,i,0,flg1);
}
}else {
for(int i = 0;i <= 9;i ++){
bool flg1 = 0;
if(abs(i - pre)<2)continue;
res += dfs(pos - 1,i,0,0);
}
}
}else {
if(head){
for(int i = 0;i <= ans[pos-1];i ++){
bool flg1 = 0;
bool flg2 = 0;
if(i==0)flg1 = 1;
if(i==ans[pos-1])flg2 = 1;
res += dfs(pos - 1,i,flg2,flg1);
}
}else {
for(int i = 0;i <= ans[pos-1];i ++){
bool flg1 = 0;
bool flg2 = 0;
if(abs(i - pre)<2)continue;
if(i==ans[pos-1])flg2 = 1;
res += dfs(pos - 1,i,flg2,0);
}
}
}
dp[pos][pre][reach][head] = res;
return res;
}
int sol(int x){
if(x<10)return x + 1;
ans.clear();
int cnt = 0;
while(x){
ans.push_back(x%10);
x/=10;
cnt ++;
}
int res = 0;
for(int i = 0;i <= ans[cnt - 1];i ++){
bool flg1 = 0,flg2 = 0;
if(i==0)flg1 = 1;
if(i==ans[cnt-1])flg2 = 1;
res += dfs(cnt - 1,i,flg2,flg1);
}
return res;
}
signed main(){
int a,b;
cin >> a >> b ;
cout << sol(b) - sol(a-1);
return 0;
}
例题:花神的数论题
该题只需转化一下,变为二进制,同时把答案统计方式从加法改为乘法即可。
#include<bits/stdc++.h>
using namespace std;
#define int long long
int const maxn = 100;
int const mod = 1e7+7;
vector<int> num;
int dp[maxn][maxn][2];
int dfs(int pos,int cnt,int reach){
int res = 1;
if(dp[pos][cnt][reach])return dp[pos][cnt][reach];
if(pos==1){
if(!reach){
if(cnt!=0)res = res *cnt%mod;
res = res * (cnt + 1) %mod;
}else {
for(int i = 0;i <= num[0];i ++){
if(cnt + i==0)continue;
res = res * (cnt + i)%mod;
}
}
return res;
}
if(!reach){
res = res*dfs(pos-1,cnt+1,0)%mod;
res = res*dfs(pos-1,cnt,0)%mod;
}else {
for(int i = 0;i <= num[pos- 1];i ++){
res = res *dfs(pos-1,cnt + i,(i==num[pos-1]))%mod;
}
}
dp[pos][cnt][reach] = res;
return res;
}
int sol(int x){
while(x){
num.push_back(x&1);
x/=2;
}
if(num.size()<=1){
if(num[0]){
return 1;
}else {
return 0;
}
}
int res = 1;
res = dfs(num.size() - 1,1,1)%mod;
res = res * dfs(num.size() - 1,0,0)%mod;
return res;
}
signed main(){
int n;
cin >> n;
cout << sol(n);
}

浙公网安备 33010602011771号