AT_agc034_f RNG and XOR
一个暴力的想法是 \(f_i\) 表示 \(0\to i\) 的期望步数,但是可以发现这其实不好转移。因为 \(f_j\to f_i\) 时我们不知道是否经过了 \(i\)。
这个问题实际上就是我们固定了起点 \(0\),终点不固定。但是如果我们固定终点 \(0\),\(f_i\) 表示 \(i\to 0\) 的期望就是好做的了。
转移就是 \(f_i\leftarrow 1+\sum\limits_{0\le j<2^n}f_{i\oplus j}\times p_j\),注意 \(f_0=0\)。
一个想法是写成异或卷积的形式,\((f_0,f_1,\cdots,f_{2^n-1})\times(p_0,p_1,\cdots,p_{2^n-1})=(C,f_1-1,f_2-1,\cdots,f_{2^n-1}-1)\),其中 \(C\) 是一个我们未知的常数。
注意到 \(\sum p=1\),因此我们其实是可以求出 \(C\) 的。具体地,我们知道右侧的和为 \(\sum f\),于是 \(C=f_0+2^n-1\)。
但是现在右侧有 \(f\),我们不好求逆。考虑 \(p_0\leftarrow p_0-1\),那么就有 \((f_0,f_1,\cdots,f_{2^n-1})\times(p_0-1,p_1,p_2,\cdots,p_{2^n-1})=(2^n-1,-1,-1,\cdots,-1)\)。
我们已知了两个序列,现在就可以直接求逆了……吗?
注意这两个序列的和都 \(=0\),这意味着两个序列的 \(\operatorname{FWT}\) 后的第 \(0\) 项都 \(=0\)。
但是我们至少知道,右侧的 \(\operatorname{FWT}\) 只有在第 \(0\) 位 \(=0\),因此我们其实求出了 \(\operatorname{FWT}(f)_{1\sim 2^n-1}\)。
我们不知道 \(\operatorname{FWT}(f)_0\),但是有 \(f_0=0\),考虑直接把 \(\operatorname{FWT}(f)_0\) 作为参数进行 \(\operatorname{IFWT}\),然后再通过 \(f_0=0\) 求出 \(\operatorname{FWT}(f)_0\)。
实际上这个东西写出来也是和题解中的“整体偏移”是等价的。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int mod = 998244353;
const int inv2 = (mod + 1) / 2;
void Add(int &x, ll y) { x = (x + y) % mod; }
int Pow(int x, int y) {
int b = x, r = 1;
for(; y; b = (ll)b * b % mod, y /= 2) {
if(y & 1) r = (ll)r * b % mod;
}
return r;
}
namespace POLY {
void FWT(vector<int> &a) {
int n = a.size();
for(int i = 0; i < __lg(n); i++) {
for(int j = 0; j < n; j++) {
if((j >> i) & 1) continue;
int v0 = a[j], v1 = a[j ^ (1 << i)];
a[j] = (v0 + v1) % mod;
a[j ^ (1 << i)] = (v0 + mod - v1) % mod;
}
}
}
void IFWT(vector<int> &a) {
int n = a.size();
for(int i = 0; i < __lg(n); i++) {
for(int j = 0; j < n; j++) {
if((j >> i) & 1) continue;
int v0 = a[j], v1 = a[j ^ (1 << i)];
a[j] = (ll)inv2 * (v0 + v1) % mod;
a[j ^ (1 << i)] = (ll)inv2 * (v0 + mod - v1) % mod;
}
}
}
vector<int> Multiply(vector<int> a, vector<int> b) {
FWT(a), FWT(b);
for(int i = 0; i < a.size(); i++) a[i] = (ll)a[i] * b[i] % mod;
IFWT(a);
return a;
}
}
int main() {
// freopen("1.in", "r", stdin);
// freopen("1.out", "w", stdout);
ios::sync_with_stdio(0), cin.tie(0);
int n;
cin >> n;
vector<int> p (1 << n);
for(int &x : p) cin >> x;
const int inv = Pow(accumulate(p.begin(), p.end(), 0), mod - 2);
for(int &x : p) x = (ll)x * inv % mod;
p[0]--;
POLY::FWT(p);
vector<int> a (1 << n, mod - 1);
a[0] = (1 << n) - 1;
POLY::FWT(a);
vector<int> fv (1 << n, 0);
for(int i = 1; i < (1 << n); i++) {
fv[i] = (ll)a[i] * Pow(p[i], mod - 2) % mod;
}
POLY::IFWT(fv);
for(int i = 0; i < (1 << n); i++) {
cout << (fv[i] + mod - fv[0]) % mod << "\n";
}
return 0;
}
浙公网安备 33010602011771号