题解:AT_abc367_g [ABC367G] Sum of (XOR^K or 0)
不是很困难的题。
题意:给出一个长度为 \(n\) 的序列 \(a\) 和 \(m,K\),从中任意选一个子序列 \(A\) 出来,要求长度是 \(m\) 的倍数,贡献为 \((\bigoplus\limits_{i\in A} a_i)^K\)。求所有合法子序列的贡献之和。\(n\le 2\times 10^5,m\le 100,a_i\le 2^{20}\)。
做法:
首先先不管那个 \(m\) 的限制的话,这就是个 fwt 板子题,也就是求 \(\prod (1+x^{a_i})\) 的结果,这里是异或卷积。这个东西很经典,考虑 fwt 本质是在做啥,对于 \(f(x) = x^{t}\) 而言,其 fwt 的结果是 \(\sum\limits_{i=0}^{2^k-1}(-1)^{\operatorname{popcount}(t\&i)}x^i\)。这个系数只有 \(1\) 或者 \(-1\),所以最后对位乘出来的只跟 \(1,-1\) 的个数有关。所以我们可以直接先算 \(g(x) = \sum x^a_i\) 的 fwt 结果,然后我们就可以知道有多少个 \(1\) 和 \(-1\),进而就可以求出来 \(\prod (1+x^{a_i})\) fwt 后的结果,直接再 ifwt 回来即可。
考虑咋处理这个 \(m\) 的限制,我们引入一个 \(z\) 进行占位。那么就变成了 \(\prod(a+x^{a_i}z)\),对 \(x\) 做异或卷积,对 \(z\) 做循环卷积,最后取 \([z^0]\) 的项的结果。
那么我们可以把 \(x\) 当作主元,把 \(z\) 当成 fwt 的系数来处理,那么这样我们对每个项 fwt 再乘在一起后结果就是 \((1+z)^x(1-z)^{n-x}\),这个 \(x\) 可以和我们上面说的一样求出来。
然后怎么做呢?fwt 有线性性,所以我们完全可以对于每个 \(z\) 的次数单独解了然后再 ifwt 回来。所以我们发现我们如果现在解出来 \(z\) 的 \(0\) 次项的系数然后再 ifwt 回来就是对的。那么现在就要求 \((1+z)^x(1-z)^{n-x}\) 循环卷积 \(0\) 次项的系数。
注意到 \(m\) 只有 \(100\),且 \(x\) 只有 \(n\) 项,我们完全可以直接把 \((1+z)^x\) 每项的系数暴力通过 \((1+z)^{x-1}\) 解出来,\((1-z)^{n-x}\) 类似,那么枚举 \((1+z)\) 贡献的次数然后找对应的 \((1-z)^{n-x}\) 中次数即可。
最后按题目中的要求算一下答案即可。
代码:
using namespace std;
#define int long long
const int maxn = 11e5 + 5, k = 20, mod = 998244353;
int n, m, K, b[maxn];
struct Poly_xor {
vector<int> a;
void resize(int N) {
a.resize(N);
}
int size() {
return a.size();
}
int& operator[](int x) {
return a[x];
}
void fwt(int f) {
int n = size();
for (int h = 2; h <= size(); h <<= 1)
for (int i = 0; i < size(); i += h)
for (int j = i; j < i + h / 2; j++) {
int a0 = a[j], a1 = a[j + h / 2];
a[j] = (a0 + a1) * f % mod, a[j + h / 2] = (a0 - a1 + mod) * f % mod;
}
}
} f;
int qpow(int x, int k, int p) {
int res = 1;
while(k) {
if(k & 1)
res = res * x % p;
x = x * x % p, k >>= 1;
}
return res;
}
struct Poly {
vector<int> a;
void resize(int N) {
a.resize(N);
}
int size() {
return a.size();
}
int& operator[](int x) {
return a[x];
}
void shift() {
int v = a[a.size() - 1];
for (int i = a.size() - 1; i >= 1; i--)
a[i] = a[i - 1];
a[0] = v;
}
friend Poly operator+(Poly f, Poly g) {
int d = max(f.size(), g.size());
f.resize(d), g.resize(d);
for (int i = 0; i < d; i++)
f[i] = (f[i] + g[i]) % mod;
return f;
}
friend Poly operator-(Poly f, Poly g) {
int d = max(f.size(), g.size());
f.resize(d), g.resize(d);
for (int i = 0; i < d; i++)
f[i] = (f[i] - g[i] + mod) % mod;
return f;
}
} pre[maxn], suf[maxn];
int get_val(Poly f, Poly g) {
int ans = 0;
for (int i = 0; i < f.size(); i++)
ans = (ans + f[i] * g[(m - i) % m]) % mod;
return ans;
}
signed main() {
cin >> n >> m >> K;
for (int i = 1; i <= n; i++)
cin >> b[i];
pre[0].resize(m), suf[n + 1].resize(m);
pre[0][0] = suf[n + 1][0] = 1;
for (int i = 1; i <= n; i++) {
Poly t = pre[i - 1]; t.shift();
pre[i] = pre[i - 1] + t;
// for (int j = 0; j < m; j++)
// cout << pre[i][j] << " ";
// cout << endl;
}
for (int i = n; i >= 1; i--) {
Poly t = suf[i + 1]; t.shift();
suf[i] = suf[i + 1] - t;
// for (int j = 0; j < m; j++)
// cout << suf[i][j] << " ";
// cout << endl;
}
// cout << 123 << endl;
f.resize((1 << k));
for (int i = 1; i <= n; i++)
f[b[i]]++;
f.fwt(1);
// for (int i = 0; i < (1 << k); i++)
// cout << f[i] << " ";
for (int i = 0; i < (1 << k); i++) {
if(f[i] > n)
f[i] = f[i] - mod;
int t = (n + f[i]) / 2;
f[i] = get_val(pre[t], suf[t + 1]);
// cout << f[i] << endl;
}
f.fwt((mod + 1) / 2);
// for (int i = 0; i < (1 << k); i++)
// cout << f[i] << " ";
// cout << endl;
int ans = 0;
for (int i = 0; i < (1 << k); i++)
ans = (ans + qpow(i, K, mod) * f[i]) % mod;
cout << ans << endl;
return 0;
}

浙公网安备 33010602011771号