Hard Nim
https://vjudge.net/problem/黑暗爆炸-4589
FWT模板题 (a template problem for the Fast Walsh–Hadamard Transform)
// Fast Walsh-Hadamard transforms for bitwise convolutions modulo `mod`:
// OR, AND, XOR and XNOR. Typical use: init(n) once, transform both inputs with
// fwt(v, 1, cas), multiply pointwise, then fwt(result, -1, cas) to invert.
struct FWT{
    // ksm(2, mod - 2): presumably fast modular exponentiation, giving the inverse
    // of 2 by Fermat's little theorem -- assumes `mod` is an odd prime.
    const int inv2 = ksm(2, mod - 2);
    // Transform length; a power of two (at least 2) once init() has run.
    int l;
    // Sets l to the smallest power of two that is >= n (minimum 2).
    // `l < n` (not `l <= n`) is the right condition: only indices 0..n-1 must fit.
    void init(int n) {
        l = 2;
        while(l < n) l <<= 1;
    }
    // Reduces a non-negative x into [0, mod) by repeated subtraction.
    inline int qmod(int x) {
        while(x >= mod) x -= mod;
        return x;
    }
    // Returns (x - y) mod `mod` when opt == -1, otherwise (x + y) mod `mod`,
    // for x, y in [0, mod). Every intermediate stays below 2 * mod. The former
    // expression qmod(x + opt * y + mod) reached about 3 * mod when opt == 1,
    // which overflows int (undefined behavior) for a modulus near 1e9.
    inline int addsub(int x, int y, int opt) {
        return opt == -1 ? qmod(x - y + mod) : qmod(x + y);
    }
    // In-place transform of a.
    //   opt: 1 = forward transform, -1 = inverse transform.
    //   cas: 1 = OR, 2 = AND, 3 = XOR, anything else = XNOR.
    // a is resized to exactly l (indices 0..l-1 are all the transform touches);
    // entries at index >= l are dropped, so call init() with n >= a.size() first.
    // Entries are expected to lie in [0, mod).
    // The block size k may be enumerated in either order (2, 4, ..., l or
    // l, ..., 4, 2) with the same result; both orders perform l / 2 butterflies
    // per level, so neither is asymptotically cheaper.
    void fwt(vector<int> &a, int opt, int cas) {
        a.resize(l);
        for(int k = 2; k <= l; k <<= 1) {
            int mid = k >> 1;
            for(int i = 0; i < l; i += k) {
                for(int j = i, up = i + mid; j < up; ++ j) {
                    if(cas == 1) {
                        // OR: the upper half (bit `mid` set) accumulates the lower half.
                        a[j + mid] = addsub(a[j + mid], a[j], opt);
                    }
                    else if(cas == 2) {
                        // AND: the lower half (bit `mid` clear) accumulates the upper half.
                        a[j] = addsub(a[j], a[j + mid], opt);
                    }
                    else if(cas == 3) {
                        // XOR: butterfly (x + y, x - y). It squares to 2*I, so the inverse
                        // is the same butterfly followed by a halving at every level.
                        int x = a[j], y = a[j + mid];
                        a[j] = qmod(x + y);
                        a[j + mid] = qmod(x - y + mod);
                        if(opt == -1) a[j] = 1ll * a[j] * inv2 % mod, a[j + mid] = 1ll * a[j + mid] * inv2 % mod;
                    }
                    else {
                        // XNOR: butterfly (y - x, x + y). It also squares to 2*I, so the
                        // inverse again halves at every level.
                        int x = a[j], y = a[j + mid];
                        a[j] = qmod(y - x + mod);
                        a[j + mid] = qmod(x + y);
                        if(opt == -1) a[j] = 1ll * a[j] * inv2 % mod, a[j + mid] = 1ll * a[j + mid] * inv2 % mod;
                    }
                }
            }
        }
    }
}f;
// Bitwise-convolution transforms modulo `mod` (OR, AND, XOR) that loop over one
// bit k at a time instead of over block sizes.
// Author's note: its constant factor is slightly smaller than FWT's (not measured here).
struct FMT{
    // ksm(2, mod - 2): presumably fast modular exponentiation, giving the inverse
    // of 2 by Fermat's little theorem -- assumes `mod` is an odd prime.
    const int inv2 = ksm(2, mod - 2);
    // Transform length; a power of two (at least 2) once init() has run.
    int l;
    // Sets l to the smallest power of two that is >= n (minimum 2).
    void init(int n) {
        l = 2;
        while(l < n) l <<= 1;
    }
    // Reduces a non-negative x into [0, mod) by repeated subtraction.
    inline int qmod(int x) {
        while(x >= mod) x -= mod;
        return x;
    }
    // Returns (x - y) mod `mod` when opt == -1, otherwise (x + y) mod `mod`,
    // for x, y in [0, mod). Every intermediate stays below 2 * mod. The former
    // expression qmod(x + opt * y + mod) reached about 3 * mod when opt == 1,
    // which overflows int (undefined behavior) for a modulus near 1e9.
    inline int addsub(int x, int y, int opt) {
        return opt == -1 ? qmod(x - y + mod) : qmod(x + y);
    }
    // In-place transform of a.
    //   opt: 1 = forward transform, -1 = inverse transform.
    //   cas: 1 = OR, 2 = AND, anything else = XOR.
    // a is resized to exactly l; entries at index >= l are dropped, so call
    // init() with n >= a.size() first. Entries are expected to lie in [0, mod).
    void fmt(vector<int> &a, int opt, int cas) {
        a.resize(l);
        for(int k = 1; k < l; k <<= 1) {
            if(cas == 1) {
                // OR: each index with bit k set accumulates its partner without bit k.
                for(int i = 0; i < l; ++ i) if(~i & k) a[i | k] = addsub(a[i | k], a[i], opt);
            }
            else if(cas == 2) {
                // AND: each index without bit k accumulates its partner with bit k.
                for(int i = l - 1; i >= 0; -- i) if(i & k) a[i ^ k] = addsub(a[i ^ k], a[i], opt);
            }
            else {
                // XOR: butterfly (x + y, x - y) on each pair (i, i | k); the inverse
                // halves at every level.
                for(int i = 0; i < l; ++ i) {
                    if(~i & k) {
                        int x = a[i], y = a[i | k];
                        a[i] = qmod(x + y);
                        a[i | k] = qmod(x - y + mod);
                        if(opt == -1) a[i] = 1ll * a[i] * inv2 % mod, a[i | k] = 1ll * a[i | k] * inv2 % mod;
                    }
                }
            }
        }
    }
}f2;
// Prime indicator table filled by run(): after the sieve, a[i] == 1 exactly when
// 2 <= i < maxn and i is prime (a[0] and a[1] keep their zero initialization).
int a[maxn];
// Reads "n m" pairs until EOF. For each pair it builds the indicator of primes in
// [0, m], XOR-transforms it, raises every entry to the n-th power and transforms
// back. It then prints entry 0, i.e. the number of n-tuples of primes <= m
// whose XOR is 0 (reduced modulo `mod` by ksm / the transform).
void run() {
    // Sieve of Eratosthenes on [2, maxn): mark everything prime, then clear the
    // proper multiples of every number that is still marked.
    for (int p = 2; p < maxn; ++p) a[p] = 1;
    for (int p = 2; p < maxn; ++p) {
        if (!a[p]) continue;
        for (int q = 2 * p; q < maxn; q += p) a[q] = 0;
    }

    int n, m;
    while (scanf("%d %d", &n, &m) != EOF) {
        // Choice indicator for a single pile: b[i] == 1 iff i is a prime <= m.
        vector<int> b(a, a + m + 1);
        f2.init(m + 1);
        f2.fmt(b, 1, 3);                    // forward XOR transform
        for (int i = 0; i < f2.l; ++i) {
            b[i] = ksm(b[i], n);            // n independent piles: pointwise n-th power
        }
        f2.fmt(b, -1, 3);                   // back to counts indexed by XOR value
        printf("%d\n", b[0]);               // tuples whose XOR is 0
    }
}

浙公网安备 33010602011771号