快速傅里叶变换（FFT）与快速数论变换（NTT）
前言：拼尽全力也只学了个一知半解，唯一的好处是可以一知半解地背板子。<。)#)))≦ 认为这个不会考，我也这样认为，但能多学一点总是好的。
P3803 【模板】多项式乘法(FFT)
这是我的 NTT 模板（用 NTT 来做这道 FFT 模板题）。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Problem-wide constants: array capacity, the NTT prime 998244353,
// its primitive root g = 3, and gi = 3^{-1} mod 998244353.
const int N = 3000005;
const int mod = 998244353;
const int g = 3;
const int gi = 332748118;

// Shared state used by NTT and main: input degrees n and m,
// transform length lim = 2^L, and the bit-reversal permutation r.
int n, m;
int lim = 1, L = 0;
int r[N];
// Coefficient arrays of the two input polynomials (transformed in place).
int a[N], b[N];

// Returns a^b modulo `mod` using binary (square-and-multiply) exponentiation.
int fastpow(int a, int b)
{
    ll result = 1;
    for (ll base = a; b != 0; b >>= 1, base = base * base % mod)
    {
        if (b & 1) result = result * base % mod;
    }
    return static_cast<int>(result);
}
void NTT(int *A, int type)
{
for (int i = 0; i < lim; ++ i) if(i < r[i]) swap(A[i], A[r[i]]);
for (int mid = 1; mid < lim; mid <<= 1)
{
int wm = fastpow(type == 1 ? g : gi, (mod - 1) / (mid << 1));
for (int j = 0; j < lim; j += (mid << 1))
{
int w = 1;
for (int k = 0; k < mid; ++ k, w = ((ll)w * wm) % mod)
{
int x = A[j + k], y = (ll)w * A[j + k + mid] % mod;
A[j + k] = (x + y) % mod;
A[j + k + mid] = ((ll)x - y + mod) % mod;
}
}
}
}
// Reads the degrees n and m, then the coefficients of two polynomials (lowest
// degree first), and prints the n + m + 1 coefficients of their product modulo
// 998244353, separated by spaces.
int main()
{
    scanf("%d %d", &n, &m);
    for (int i = 0; i <= n; ++ i) scanf("%d", &a[i]), a[i] %= mod;
    for (int i = 0; i <= m; ++ i) scanf("%d", &b[i]), b[i] %= mod;
    // Smallest power of two strictly greater than n + m, so all n + m + 1
    // product coefficients fit without cyclic wrap-around.
    while(lim <= n + m) lim <<= 1, L ++;
    // Bit-reversal permutation. When L == 0 (lim == 1) the shift count L - 1
    // would be negative, which is undefined behaviour, so the top-bit term is omitted.
    for (int i = 0; i < lim; ++ i)
        r[i] = (r[i >> 1] >> 1) | (L > 0 ? (i & 1) << (L - 1) : 0);
    NTT(a, 1), NTT(b, 1);
    // Pointwise product of the evaluations.
    for (int i = 0; i < lim; ++ i) a[i] = ((ll)a[i] * b[i]) % mod;
    NTT(a, -1);
    // NTT(-1) does not scale, so divide by lim here.
    int inv = fastpow(lim, mod - 2);
    for (int i = 0; i <= n + m; ++ i)
    {
        // %d requires an int argument; the value is reduced below mod, so the
        // narrowing cast is lossless (passing a long long to %d is undefined behaviour).
        printf("%d ", static_cast<int>((ll)a[i] * inv % mod));
    }
    return 0;
}
【模板】高精度乘法 | A*B Problem 升级版
注意：预处理单位根表 pow3 / powinv3 时，i 要取到 lim（闭区间 [1, lim]），因为 NTT 的最后一轮会用到下标 mid << 1 == lim。
#include <bits/stdc++.h>
using namespace std;
const int N = 5000005;        // capacity of every array below
const int mod = 998244353;    // NTT prime modulus
using ll = long long;

// a, b: digit polynomials of the two factors (transformed in place);
// c: result digits after carrying; tmp: not referenced in the code shown;
// pow3 / powinv3: root-of-unity tables indexed by transform length;
// r: bit-reversal permutation; inv3: 3^{-1} mod `mod`.
int a[N], b[N], c[N], tmp[N];
int inv3;
int pow3[N], powinv3[N];
int r[N];
int lim = 1, L = 0;   // transform length (a power of two) and its log2
int n, m;             // highest digit index of each factor (length - 1)
char s[N];            // input buffer, filled starting at index 1

// Computes a^b mod `mod` by repeated squaring.
int fastpow(int a, int b)
{
    ll acc = 1;
    ll sq = a;
    while (b != 0)
    {
        const bool bitSet = (b & 1) != 0;
        if (bitSet) acc = acc * sq % mod;
        sq = sq * sq % mod;
        b >>= 1;
    }
    return (int)acc;
}
// Sizes the transform for the current n and m and precomputes the NTT tables.
// - lim: smallest power of two > n + m (the product has up to n + m + 1
//   coefficients); L = log2(lim).
// - pow3[len] / powinv3[len]: a len-th root of unity and its inverse for every
//   power of two len in [1, lim]. The closed upper bound matters because NTT
//   reads pow3[mid << 1] with mid << 1 == lim in its last round.
// - r: bit-reversal permutation of [0, lim) (NTT only reads r[i] for i < lim).
void init()
{
    while(lim <= n + m) lim <<= 1, L ++;
    inv3 = fastpow(3, mod - 2);
    for (int i = 1; i <= lim; i <<= 1) pow3[i] = fastpow(3, (mod - 1) / i);
    for (int i = 1; i <= lim; i <<= 1) powinv3[i] = fastpow(inv3, (mod - 1) / i);
    for (int i = 0; i < lim; ++ i)
    {
        // If both inputs are single digits then lim == 1 and L == 0; shifting by
        // L - 1 = -1 is undefined behaviour, and there is no top bit to set anyway.
        r[i] = (r[i >> 1] >> 1) | (L > 0 ? (i & 1) << (L - 1) : 0);
    }
    return ;
}
// In-place iterative radix-2 NTT of A[0 .. lim) modulo `mod`.
// type == 1: forward transform using the pow3 table.
// Any other value: inverse transform using powinv3. For type == -1 the result is
// also scaled by lim^{-1}, so the coefficients are recovered directly.
// Requires init() to have prepared lim, r, pow3 and powinv3.
void NTT(int *A, int type)
{
    for (int i = 0; i < lim; ++ i) if(i < r[i]) swap(A[i], A[r[i]]);
    for (int mid = 1; mid < lim; mid <<= 1)
    {
        // Root of unity of order 2 * mid (or its inverse for the inverse transform).
        int wn = (type == 1) ? pow3[mid << 1] : powinv3[mid << 1];
        for (int j = 0; j < lim; j += (mid << 1))
        {
            int w = 1;
            for (int k = 0; k < mid; ++ k, w = (ll)w * wn % mod)
            {
                int x = A[j + k], y = (ll)w * A[j + k + mid] % mod;
                A[j + k] = ((ll)x + y) % mod;
                A[j + k + mid] = ((ll)x - y + mod) % mod;
            }
        }
    }
    if(type == -1)
    {
        // Scale by 1 / lim. This must act on the array argument A, not the
        // global a, or the inverse transform of any other array is left unscaled.
        int num = fastpow(lim, mod - 2);
        for (int i = 0; i < lim; ++ i) A[i] = (ll)A[i] * num % mod;
    }
    return ;
}
// Reads two non-negative integers as decimal strings and prints their product.
int main()
{
    // Digits are stored least-significant first: a[i] is the coefficient of 10^i.
    scanf("%s", s + 1);
    n = strlen(s + 1) - 1;
    for (int i = 0; i <= n; ++ i) a[i] = s[n - i + 1] - '0';
    scanf("%s", s + 1);
    m = strlen(s + 1) - 1;
    for (int i = 0; i <= m; ++ i) b[i] = s[m - i + 1] - '0';
    init();
    NTT(a, 1);
    NTT(b, 1);
    for (int i = 0; i < lim; ++ i) a[i] = (ll)a[i] * b[i] % mod;
    NTT(a, -1);
    // a[i] now holds the uncarried digit convolution (exact as long as each
    // column sum stays below mod). Propagate carries.
    for (int i = 0; i < lim; ++ i) c[i] = a[i];
    for (int i = 0; i < lim; ++ i)
    {
        if(c[i] >= 10)
        {
            c[i + 1] += c[i] / 10;
            c[i] %= 10;
        }
    }
    // The product has at most n + m + 2 digits and lim >= n + m + 1, so c[lim]
    // receives the final carry and is already a single digit.
    int pp = lim;
    // Strip leading zeros but keep at least one digit: a zero product must print
    // "0" instead of walking off the front of c.
    while(pp > 0 && c[pp] == 0) pp --;
    for (int i = pp; i >= 0; -- i) printf("%d", c[i]);
    return 0;
}
1096G - Lucky Tickets
计算位反转数组 r 时忘了写 r[i >> 1] >> 1 里的右移，调了很久。这题的转化很有技巧：求多项式的 n 次方时，可以先估出结果的最高次数（即需要多少个点值），只做一次正变换，把每个点值分别快速幂，再做一次逆变换还原出系数即可。
#include <bits/stdc++.h>
using namespace std;
const int N = 3000005;       // array capacity
const int mod = 998244353;   // NTT prime modulus
using ll = long long;

// Fast modular exponentiation: returns a^b mod `mod`.
int quick(int a, int b)
{
    ll answer = 1;
    ll power = a;
    for (; b; b >>= 1)
    {
        answer = (b & 1) ? answer * power % mod : answer;
        power = power * power % mod;
    }
    return static_cast<int>(answer);
}
int n, lim = 1, L = 0, a[N], k, mx = 0;
int r[N], inv3, num = 0;
void NTT(int *A, int type)
{
for (int i = 0; i < lim; ++ i) if(i < r[i]) swap(A[i], A[r[i]]);
for (int mid = 1; mid < lim; mid <<= 1)
{
int wn = quick(3, (mod - 1) / (mid << 1));
if(type == -1) wn = quick(wn, mod - 2);
for (int j = 0; j < lim; j += (mid << 1))
{
int w = 1;
for (int z = 0; z < mid; ++ z, w = (ll)w * wn % mod)
{
int x = A[j + z], y = (ll)w * A[j + z + mid] % mod;
A[j + z] = (x + y) % mod;
A[j + z + mid] = ((x - y) % mod + mod) % mod;
}
}
}
if(type == -1)
{
for (int i = 0; i < lim; ++ i) A[i] = (ll)A[i] * num % mod;
}
return ;
}
// Reads n and k, then the k allowed digits. Prints the sum over all digit sums s
// of cnt[s]^2 mod 998244353, where cnt[s] is the number of (n / 2)-digit strings
// over the allowed digits whose digit sum is s.
int main()
{
    scanf("%d %d", &n, &k);
    for (int i = 1; i <= k; ++ i)
    {
        int x;
        scanf("%d", &x);
        a[x] = 1;
        mx = max(mx, x);
    }
    // A string of n / 2 digits has digit sum at most (largest digit) * (n / 2).
    mx = mx * (n / 2);
    lim = 1, L = 0;
    while(lim <= mx) lim <<= 1, L ++;
    num = quick(lim, mod - 2);
    inv3 = quick(3, mod - 2);
    // Bit-reversal permutation of [0, lim). Note the right shift in r[i >> 1] >> 1.
    // If the only allowed digit is 0 then mx == 0, lim == 1 and L == 0; shifting by
    // L - 1 = -1 would be undefined behaviour, so the top-bit term is omitted then.
    for (int i = 0; i < lim; ++ i) r[i] = (r[i >> 1] >> 1) | (L > 0 ? (i & 1) << (L - 1) : 0);
    NTT(a, 1);
    // Raising every point value to the power n / 2 computes the (n / 2)-th power
    // of the digit polynomial in one pass.
    for (int i = 0; i < lim; ++ i) a[i] = quick(a[i], n / 2);
    NTT(a, -1);
    // a[s] = number of (n / 2)-digit strings with digit sum s; pair equal sums.
    int ans = 0;
    for (int i = 0; i < lim; ++ i) ans = (ans + (ll)a[i] * a[i] % mod) % mod;
    printf("%d", ans);
    return 0;
}

浙公网安备 33010602011771号