Educational Codeforces Round 53 (Rated for Div. 2) E. Segment Sum

应该是我做过最复杂的数位dp了
做之前已经忘了数位dp是啥了,找了一个博客复习下

与普通数位dp不同,这题要求的是满足条件的数之和,而不是个数,
所以每个状态需要维护更多的量
dp[i][j] 存 val、cnt、remain 三个量:表示处理到第 i 位、已出现的数字集合为 mask j(0-9 各占一个二进制位)时,第 i 位到最低位能贡献的数值总和、合法填法的个数,以及到第 i 位为止仍全是前导 0 的那些数(即位数更短的数)的总和——这部分需要与 val 分开累计

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <climits>
#include <cstring>
#include <vector>
#include <list>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <algorithm>
#include <functional>
#include <assert.h>
#include <iomanip>
#include <unordered_map>
using namespace std;

// Problem-wide constants.
constexpr int N = 2005;          // not referenced by this solution
constexpr int INF = 0x3f3f3f3f;  // not referenced by this solution
constexpr int MOD = 998244353;   // all sums are reported modulo this prime

using ll = long long;

// Query state shared between main(), solve() and dfs().
long long l, r;  // query range [l, r] read by main()
int k;           // maximum number of distinct decimal digits a number may use
int digit[20];   // decimal digits of the current bound, least significant at index 0

struct Node{
    ll val;
    ll cnt;
    ll remain;
    Node(ll a=0, ll b=0, ll c=0):val(a), cnt(b), remain(c) {}
}dp[20][1024];
ll ten[20];

// Digit DP over the digits of the current bound stored in digit[] (index pos is the
// position being filled; higher indices were filled by the callers).
//
// Parameters:
//   pos   - position to fill next; -1 means every position has been filled.
//   mask  - bitmask of digits used after the leading zeros (bit d set <=> digit d used).
//   num   - prefix value built so far; not used in the computation, kept for the
//           existing call signature.
//   lead  - true while every position above pos has been a leading zero.
//   limit - true while the prefix equals the bound's prefix, so digit pos may not
//           exceed digit[pos].
//
// Returns, over completions of positions pos..0 that use at most k distinct digits:
//   val    - for completions that are past the leading zeros at position pos: the sum
//            (mod MOD) of the value contributed by positions pos..0,
//   cnt    - the number (mod MOD) of those completions,
//   remain - for completions whose digit at pos is still a leading zero (i.e. shorter
//            numbers): the sum (mod MOD) of those numbers; always 0 when lead is false.
// The full answer for the bound is val + remain of the top-level call.
//
// Only states with !lead && !limit are memoized in dp[pos][mask]. Their result depends
// solely on (pos, mask, k). The caller must reset dp (val = -1) whenever the bound or k
// changes.
Node dfs(int pos, int mask, ll num, bool lead, bool limit) {
    if (pos == -1) return Node(0, 1, 0);  // exactly one completed number; its digits were summed by callers

    const bool cacheable = !lead && !limit;
    if (cacheable && dp[pos][mask].val != -1) return dp[pos][mask];

    ll ans = 0;
    ll cnt = 0;
    ll remain = 0;
    const int up = limit ? digit[pos] : 9;

    for (int i = 0; i <= up; ++i) {
        const bool nextLimit = limit && i == digit[pos];

        if (lead && i == 0) {
            // Another leading zero: the mask is unchanged, and every number found below
            // is shorter, so its full value is accumulated separately in remain.
            const Node sub = dfs(pos - 1, mask, num * 10, true, nextLimit);
            remain = (remain + sub.val + sub.remain) % MOD;
            continue;
        }

        const int nextMask = mask | (1 << i);
        if (static_cast<int>(bitset<10>(nextMask).count()) > k) continue;  // too many distinct digits

        const Node sub = dfs(pos - 1, nextMask, num * 10 + i, false, nextLimit);
        // Digit i at this position contributes i * 10^pos once per valid completion below.
        ans = (ans + i * ten[pos] % MOD * sub.cnt % MOD + sub.val) % MOD;
        cnt = (cnt + sub.cnt) % MOD;
        remain = (remain + sub.remain) % MOD;
    }

    const Node res(ans, cnt, remain);
    // Lead/limit states depend on the specific bound, so they must never populate the memo.
    if (cacheable) dp[pos][mask] = res;
    return res;
}


// Returns the sum (mod MOD) of all integers in [0, x] that use at most k distinct
// decimal digits. Non-positive x yields 0.
ll solve(ll x) {
    // [0, x] contributes nothing for x <= 0. For negative x, x % 10 would also produce
    // negative "digits" and corrupt the digit DP.
    if (x <= 0) return 0;

    // Reset the memo: cached values depend on k, which may differ between queries.
    for (auto& row : dp)
        for (Node& cell : row)
            cell = Node(-1, 0, 0);

    // Split x into decimal digits, least significant first.
    int len = 0;
    for (; x > 0; x /= 10) digit[len++] = x % 10;

    // Start at the most significant digit: no digits used yet, leading-zero and tight.
    const Node top = dfs(len - 1, 0, 0, true, true);
    return (top.val + top.remain) % MOD;
}

// Reads test cases "l r k" until EOF and prints, for each, the sum (mod MOD) of the
// integers in [l, r] that use at most k distinct decimal digits.
int main() {
    // ten[i] = 10^i mod MOD. A long long has at most 19 decimal digits, so dfs only
    // reads ten[0..18], but the whole table is filled for consistency with its size.
    ten[0] = 1;
    for (int i = 1; i < 20; ++i) ten[i] = ten[i - 1] * 10 % MOD;

    // Require all three fields. A matching failure returns 0 without consuming input,
    // which the previous `~scanf(...)` test would have turned into an infinite loop.
    while (scanf("%lld %lld %d", &l, &r, &k) == 3) {
        // Prefix-sum difference; + MOD keeps the result non-negative.
        printf("%lld\n", (solve(r) - solve(l - 1) + MOD) % MOD);
    }
    return 0;
}
posted @ 2018-12-05 22:42  basasuya  阅读(198)  评论(0)  编辑  收藏  举报