[题解] AT_ABC406_E Popcount Sum 3

洛谷题目传送门

ATcoder题目链接

题目大意

给你正整数 \(N\)\(K\)

求所有不超过 \(N\) 且满足以下条件的正整数 \(x\)(取模 \(998244353\)):

  • \(x\)\(\mathrm{popcount}\) 恰好是 \(K\)

给你 \(T\) 个测试用例,请逐个求解。

什么是 \(\mathrm{popcount}\)

对于正整数 \(y\)\(\mathrm{popcount}(y)\) 表示 \(y\) 的二进制表示中 \(1\) 位的数目。

例如,\(\mathrm{popcount}(5)=2\)\(\mathrm{popcount}(16)=1\)\(\mathrm{popcount}(25)=3\)

思路

因为题目让我们在一个范围 \(1\sim N\) 中符合条件的数的和,\(N\) 的上界是 \(2^{60}\),是个二进制数,可以拆位。

所以,我们考虑使用数位dp来做。

状态定义

\(f_{i,j}\) 表示从高望低数前 \(i\) 位二进制,恰好有 \(j\)\(1\) 的个数。
\(s_{i,j}\) 表示这些数的和。

状态转移:

\(f_{i,j}=f_{i+1,j}+f_{i+1,j+1}\)
\(s_{i,j}=s_{i+1,j}+s_{i+1,j+1}+2^i\times f_{i+1,j+1}\)

解释

\(s_{i+1,j}\)\(s_{i+1,j+1}\) 表示第 \(i\) 位为 \(0/1\) 的数字和。
\(2^i\times s_{i+1,j+1}\) 表示第 \(i\) 位为 \(1\) 的贡献(第 \(i\) 位数 \(\times\) 方案数)。

做法

  • 用记忆化搜索进行数位dp。
  • pair<int,int> 代表 \(f\)\(s\) 数组。

代码

#include <bits/stdc++.h>
#define endl '\n'
#define int long long

using namespace std;
const int N = 105;
const int mod = 998244353;
vector<int> bits;
int n, k;
pair<int,int> dp[N][N][2];
int qpow(int a, int b) {
	a %= mod;
	int ans = 1;
	while (b) {
		if (b & 1) ans = ans * a % mod;
		a = a * a % mod;
		b >>= 1;
	}
	return ans % mod;
}
pair<int, int> dfs(int pos, int cnt, int tag){
	//当前处理到pos为,有cnt个1,高位是否存在限制
	if(pos == bits.size()){
		if(cnt == k) return {1, 0};
		else return {0, 1};
	}
	if(dp[pos][cnt][tag].first != -1){
		return dp[pos][cnt][tag];
	}
	int rcnt = 0, rsum = 0;
	int mx = tag? bits[pos] : 1;
	for(int d = 0; d <= mx; d++){
		if(cnt + d > k) continue;
		auto [ncnt, nsum] = dfs(pos + 1, cnt + d, tag & (d == bits[pos]));
		if(ncnt == 0) continue;
		rcnt += ncnt;
		int len = bits.size();
		rsum += d * qpow(2, len-1-pos) % mod * ncnt % mod + nsum;
		rcnt %= mod;
		rsum %= mod;
	}
	dp[pos][cnt][tag] = {rcnt, rsum};
	return dp[pos][cnt][tag];
}
inline void init(){
	memset(dp, -1, sizeof(dp));
	bits.clear();
	return ;
}
void solve(){
	init();
	cin >> n >> k;
	int m = n;
	while(m > 0){
		bits.push_back(m % 2);
        m /= 2;
	}
	reverse(bits.begin(), bits.end());
	if(k > bits.size()) {
		cout << 0 << endl;
		return ;
	}
	return cout << dfs(0, 0, 1).second << endl,void();
}
inline int read(){int x;cin>>x;return x;}
signed main(){
	cin.tie(0)->sync_with_stdio(0);
	int _=read();
	while(_--) solve();
	return 0;
}
posted @ 2025-05-18 20:57  酱云兔  阅读(123)  评论(0)    收藏  举报