[题解] AT_ABC406_E Popcount Sum 3
题目大意
给你正整数 \(N\) 和 \(K\)。
求所有不超过 \(N\) 且满足以下条件的正整数 \(x\) 的 和(取模 \(998244353\)):
- \(x\) 的 \(\mathrm{popcount}\) 恰好是 \(K\) 。
给你 \(T\) 个测试用例,请逐个求解。
什么是 \(\mathrm{popcount}\)?
对于正整数 \(y\),\(\mathrm{popcount}(y)\) 表示 \(y\) 的二进制表示中 \(1\) 位的数目。
例如,\(\mathrm{popcount}(5)=2\),\(\mathrm{popcount}(16)=1\),\(\mathrm{popcount}(25)=3\)。
思路
因为题目让我们在一个范围 \(1\sim N\) 中符合条件的数的和,\(N\) 的上界是 \(2^{60}\),是个二进制数,可以拆位。
所以,我们考虑使用数位dp来做。
状态定义
设 \(f_{i,j}\) 表示从高望低数前 \(i\) 位二进制,恰好有 \(j\) 个 \(1\) 的个数。
\(s_{i,j}\) 表示这些数的和。
状态转移:
\(f_{i,j}=f_{i+1,j}+f_{i+1,j+1}\)
\(s_{i,j}=s_{i+1,j}+s_{i+1,j+1}+2^i\times f_{i+1,j+1}\)
解释
\(s_{i+1,j}\) 和 \(s_{i+1,j+1}\) 表示第 \(i\) 位为 \(0/1\) 的数字和。
\(2^i\times s_{i+1,j+1}\) 表示第 \(i\) 位为 \(1\) 的贡献(第 \(i\) 位数 \(\times\) 方案数)。
做法
- 用记忆化搜索进行数位dp。
- 用
pair<int,int>代表 \(f\) 和 \(s\) 数组。
代码
#include <bits/stdc++.h>
#define endl '\n'
#define int long long
using namespace std;
const int N = 105;
const int mod = 998244353;
vector<int> bits;
int n, k;
pair<int,int> dp[N][N][2];
int qpow(int a, int b) {
a %= mod;
int ans = 1;
while (b) {
if (b & 1) ans = ans * a % mod;
a = a * a % mod;
b >>= 1;
}
return ans % mod;
}
pair<int, int> dfs(int pos, int cnt, int tag){
//当前处理到pos为,有cnt个1,高位是否存在限制
if(pos == bits.size()){
if(cnt == k) return {1, 0};
else return {0, 1};
}
if(dp[pos][cnt][tag].first != -1){
return dp[pos][cnt][tag];
}
int rcnt = 0, rsum = 0;
int mx = tag? bits[pos] : 1;
for(int d = 0; d <= mx; d++){
if(cnt + d > k) continue;
auto [ncnt, nsum] = dfs(pos + 1, cnt + d, tag & (d == bits[pos]));
if(ncnt == 0) continue;
rcnt += ncnt;
int len = bits.size();
rsum += d * qpow(2, len-1-pos) % mod * ncnt % mod + nsum;
rcnt %= mod;
rsum %= mod;
}
dp[pos][cnt][tag] = {rcnt, rsum};
return dp[pos][cnt][tag];
}
inline void init(){
memset(dp, -1, sizeof(dp));
bits.clear();
return ;
}
void solve(){
init();
cin >> n >> k;
int m = n;
while(m > 0){
bits.push_back(m % 2);
m /= 2;
}
reverse(bits.begin(), bits.end());
if(k > bits.size()) {
cout << 0 << endl;
return ;
}
return cout << dfs(0, 0, 1).second << endl,void();
}
inline int read(){int x;cin>>x;return x;}
signed main(){
cin.tie(0)->sync_with_stdio(0);
int _=read();
while(_--) solve();
return 0;
}

浙公网安备 33010602011771号