CF2096H Wonderful XOR Problem(FWT, *)

Code
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define sz(a) (int)(a).size()
#define all(a) (a).begin(), (a).end()
// Debug helper: renders any iterable container as "{e1, e2, ...}".
//
// Every element is formatted by calling to_string again, so arithmetic
// elements go through std::to_string and nested containers recurse into this
// template. The container is taken by const reference and iterated by const
// reference, so formatting never copies the container or its elements.
//
// Returns the formatted string; an empty container yields "{}".
template<typename A>
std::string to_string(const A& v) {
    // Name both overload sets explicitly so the element call below resolves
    // the same way whether or not a file-level using-directive is in effect.
    using std::to_string;
    using ::to_string;
    std::string s = "{";
    for (const auto& x : v) {
        // Anything beyond the opening brace means an element was already written.
        if (s.size() > 1) {
            s += ", ";
        }
        s += to_string(x);
    }
    s += "}";
    return s;
}
// Ends a debug line: writes a single newline to stderr.
void debug_out() {
    std::cerr << "\n";
}

// Writes every argument to stderr, each one preceded by a single space and
// formatted through to_string, then ends the line.
template<typename T, typename... U>
void debug_out(const T& x, const U&... args) {
    std::cerr << " " << to_string(x);
    // Left-to-right comma fold: same output order as peeling one argument per call.
    ((std::cerr << " " << to_string(args)), ...);
    debug_out();
}
#define debug(...) // cerr << "[" << #__VA_ARGS__ << "]:", debug_out(__VA_ARGS__)
// Modulus used for all arithmetic in this file.
constexpr int mod = 998244353;

// Performs at most one conditional subtraction of mod, in place.
// A value in [mod, 2 * mod) is brought into [0, mod); anything below mod
// (including negative values) is left untouched.
void reduce(int& x) {
    if (x >= mod) {
        x -= mod;
    }
}
// Modular exponentiation by repeated squaring: returns a^x modulo mod.
// The exponent defaults to mod - 2. Products are formed in 64-bit arithmetic
// before the remainder is taken.
int exp(int a, int x = mod - 2) {
    int result = 1;
    while (x) {
        // Multiply in the current power of a when the low exponent bit is set.
        if (x & 1) {
            result = static_cast<long long>(result) * a % mod;
        }
        a = static_cast<long long>(a) * a % mod;
        x >>= 1;
    }
    return result;
}
// Two-coefficient value; per the original author's note it stands for
// a + b * x^i. The index i is not stored in the struct.
struct bixor {
    int a;  // first coefficient (the "a" term)
    int b;  // second coefficient (the "b * x^i" term)

    // Default value is (1, 0).
    bixor() : bixor(1, 0) {}
    bixor(int a, int b) : a(a), b(b) {}
};
// Product of two bixor values modulo mod:
//   (u.a, u.b) * (v.a, v.b) = (u.a*v.a + u.b*v.b, u.a*v.b + u.b*v.a).
// In other words the second term is treated as squaring to 1. Each sum of two
// products is accumulated in 64-bit arithmetic before the remainder is taken.
bixor operator*(const bixor& u, const bixor& v) {
    const long long first =
        (static_cast<long long>(u.a) * v.a + static_cast<long long>(u.b) * v.b) % mod;
    const long long second =
        (static_cast<long long>(u.a) * v.b + static_cast<long long>(u.b) * v.a) % mod;
    return bixor(first, second);
}
// Program entry point. Reads the number of test cases tt from stdin; each test
// case supplies n and m followed by n pairs (L[i], R[i]), and one integer is
// printed per test case.
//
// Outline of one test case, as far as the code shows:
//   1. Build an array S of size 2^m. S[0] is the product of (R[i] - L[i])
//      taken after L[i] has been decremented; every other slot is filled by
//      the loop over p, which writes exactly the indices whose lowest set bit
//      is p.
//   2. Apply the inverse transform to S (fwt with rev = true).
//   3. XOR together S[i] * 2^i mod `mod` over all i and print the result.
// NOTE(review): S looks like the Walsh-Hadamard spectrum of the per-XOR-value
// counts of picking one number from each range; confirm against the problem
// statement.
int main() {
// freopen("H.in", "r", stdin);
ios::sync_with_stdio(false);
cin.tie(0);
int tt;
cin >> tt;
while (tt--) {
int n, m;
cin >> n >> m;
// s[p][i] = sum over j in [0, i] of (+1 if bit p of j is clear, -1 if it is set).
// Each table spans one full period of 2^(p + 1) entries.
vector s(m, vector<int>());
for (int p = 0; p < m; ++p) {
s[p] = vector<int>(1 << (p + 1));
s[p][0] = 1;
for (int i = 1; i < sz(s[p]); ++i) {
s[p][i] = s[p][i - 1] + (i >> p & 1 ? -1 : 1);
}
}
// Signed prefix sum through x for bit p: 0 when x is negative, otherwise the
// table entry at x reduced modulo the period (a full period sums to zero, so
// dropping the higher bits does not change the value).
auto gets = [&](int p, int x) { // x may be negative
return x < 0 ? 0 : s[p][x & ((1 << (p + 1)) - 1)];
};
// Size of the result array S.
int mask = 1 << m;
// In-place butterfly transform: at every level each pair (x, y) becomes
// (x + y, x - y) modulo mod. With rev = true every entry is afterwards
// multiplied by exp(sz(a)), i.e. the size raised to exp's default exponent
// mod - 2, which turns the transform into its own inverse.
// Entries are expected to already lie in [0, mod) so that a single reduce()
// per result is enough.
auto fwt = [&](vector<int>& a, bool rev = false) {
for (int w = 1; w < sz(a); w *= 2) {
for (int i = 0; i < sz(a); i += w * 2) {
for (int j = 0; j < w; ++j) {
auto x = a[i | j], y = a[i | w | j];
reduce(a[i | j] = x + y);
reduce(a[i | w | j] = x - y + mod);
}
}
}
if (!rev) {
return;
}
int inv = exp(sz(a));
for (int i = 0; i < sz(a); ++i) {
a[i] = a[i] * (ll)inv % mod;
}
};
// -1 marks a slot that has not been written yet; checked by the assert below.
vector S(mask, -1);
vector<int> L(n), R(n);
S[0] = 1;
for (int i = 0; i < n; i++) {
// Decrementing L[i] makes R[i] - L[i] the number of integers in the input range.
cin >> L[i] >> R[i], --L[i];
S[0] = S[0] * (ll)(R[i] - L[i]) % mod;
}
debug(S[0]);
// Fill every S index whose lowest set bit is p.
for (int p = 0; p < m; ++p) {
debug(p);
// XOR of the bits above p of every R[i].
int shift = 0;
// Number of distinct values the bits above p can take.
int len = 1 << (m - p - 1);
// D[k] is the product of the per-range factors that landed in bucket k;
// default-constructed entries are (1, 0), the identity of operator*.
vector<bixor> D(len);
for (int i = 0; i < n; ++i) {
// a = signed prefix sum through R[i]; b = minus the signed prefix sum
// through the decremented L[i]. Both are then normalised into [0, mod).
int a = gets(p, R[i]), b = -gets(p, L[i]);
reduce(a += mod);
reduce(b += mod);
// c and d are the bits above p of R[i] and of L[i] (clamped at 0).
int c = R[i] >> (p + 1), d = max(L[i], 0) >> (p + 1); // L[i] == -1 gives b == 0, so clamping it to 0 here should be harmless
debug(a, c, b, d);
shift ^= c;
// The bucket is the XOR of the two high parts.
d ^= c;
// debug(a, b, d);
D[d] = D[d] * bixor(a, b);
}
// f[i][0] = D[i].a + D[i].b and f[i][1] = D[i].a - D[i].b, both modulo mod.
vector f(len, vector<int>(2));
for (int i = 0; i < len; ++i) {
reduce(f[i][0] = D[i].a + D[i].b);
reduce(f[i][1] = D[i].a - D[i].b + mod);
debug(i, D[i].a, D[i].b);
}
// Butterfly passes with the same index pattern as fwt, but combining by
// multiplication: the lower slot gets (x0*y0, x1*y1), the upper slot gets
// (x0*y1, x1*y0).
// NOTE(review): this appears to leave in f[k][0] the product over all buckets
// i of (D[i].a + sign * D[i].b), with sign = -1 when popcount(i & k) is odd
// and +1 otherwise -- confirm.
// NOTE(review): x and y are copies of two-element vectors, so each inner
// iteration allocates; a fixed-size pair type would avoid that.
for (int w = 1; w < len; w *= 2) {
for (int i = 0; i < len; i += w * 2) {
for (int j = 0; j < w; ++j) {
auto x = f[i | j], y = f[i | w | j];
f[i | j][0] = x[0] * (ll)y[0] % mod;
f[i | j][1] = x[1] * (ll)y[1] % mod;
f[i | w | j][0] = x[0] * (ll)y[1] % mod;
f[i | w | j][1] = x[1] * (ll)y[0] % mod;
}
}
}
// Keep only the first component of every entry.
vector<int> g(len);
for (int i = 0; i < len; ++i) {
g[i] = f[i][0];
}
// Inverse-transform g, relabel its indices by XOR with shift (each pair
// i, i ^ shift is swapped exactly once), then transform forward again.
fwt(g, true);
for (int i = 0; i < len; ++i) {
if (i < (i ^ shift)) {
swap(g[i], g[i ^ shift]);
}
}
fwt(g);
// Store g[i] at the index whose bits above p are i, whose bit p is set and
// whose lower bits are clear.
for (int i = 0; i < len; ++i) {
S[i << (p + 1) | 1 << p] = g[i];
}
}
debug(S);
// Every slot must have been written by now (no -1 marker left).
assert(find(all(S), -1) == S.end());
fwt(S, true);
debug(S);
// cur starts at 1 and doubles modulo mod each step, so it equals 2^i mod `mod`;
// res is the XOR of the reduced products S[i] * cur.
int res = 0, cur = 1;
for (int i = 0; i < mask; ++i) {
res ^= S[i] * (ll)cur % mod;
reduce(cur += cur);
}
cout << res << '\n';
}
}

浙公网安备 33010602011771号