多项式相关
多项式乘法
FFT 就不讲了。
NTT 即把 FFT 中的单位根 $\omega_n^k$ 替换为 $g^{\frac{k(p-1)}{n}}$ 即可，其中 $g$ 为模数 $p$ 的原根（要求 $n \mid p - 1$；例如 $p = 998244353 = 119 \cdot 2^{23} + 1$、$g = 3$ 时，$n$ 可取不超过 $2^{23}$ 的 2 的幂）。
板子：
namespace fhqAKIOI {
// Capacity of the fixed-size scratch arrays; N leaves a small safety margin past M.
const int M = 800000, N = M + 5;
// NTT modulus 998244353 = 119 * 2^23 + 1, its primitive root G = 3, and invG = G^{-1} (mod p).
const int mod = 998244353, G = 3, invG = 332748118;
// Clock-seeded 64-bit generator; only Cipolla's random search draws from it.
std::mt19937_64 rnd(std::chrono::system_clock::now().time_since_epoch().count());

// Random integer in [l, r] (the modulo bias is irrelevant for Cipolla's search).
int gen(int l, int r) {
    return rnd() % (r - l + 1) + l;
}

// Non-residue w = x^2 - n picked by Cipolla; Ply multiplication uses it as (sqrt(w))^2.
int w;

// Element x + y * sqrt(w) of the quadratic extension F_p[sqrt(w)].
struct Ply {
    int x, y;
    Ply(int x = 0, int y = 0) : x(x), y(y) {}
    Ply operator*(const Ply &a) const {
        Ply z;
        // (x + y*s)(a.x + a.y*s) with s^2 = w.
        z.x = (1ll * a.x * x % mod + 1ll * a.y * y % mod * w % mod) % mod;
        z.y = (1ll * a.x * y % mod + 1ll * a.y * x % mod) % mod;
        return z;
    }
};

// Binary exponentiation in F_p[sqrt(w)].
Ply qpow(Ply x, int y) {
    Ply a(1, 0);
    while (y) {
        if (y & 1) a = a * x;
        x = x * x;
        y /= 2;
    }
    return a;
}

// x^y mod p by binary exponentiation (y >= 0).
int qpow(int x, int y) {
    int a = 1;
    while (y) {
        if (y & 1) a = 1ll * a * x % mod;
        x = 1ll * x * x % mod;
        y /= 2;
    }
    return a;
}

// Square root of n modulo p (Cipolla's algorithm).
// Returns 0 when n == 0 (mod p), -1 when n is a quadratic non-residue, otherwise one of its two roots.
int Cipolla(int n) {
    n %= mod;
    if (n < 0) n += mod;
    if (!n) return 0;

    // Euler's criterion: n^{(p-1)/2} == -1 means n has no square root.
    if (qpow(n, (mod - 1) / 2) == mod - 1) return -1;

    int x = 0;
    // Search for x with x^2 - n a non-residue; roughly half of all x qualify.
    while (true) {
        x = gen(0, mod - 1);
        // Widen before squaring: x * x overflows a 32-bit int for x > 46340.
        w = (1ll * x * x % mod - n + mod) % mod;
        if (qpow(w, (mod - 1) / 2) == mod - 1) break;
    }

    // (x + sqrt(w))^{(p+1)/2} lies in F_p and squares to n.
    return qpow(Ply(x, 1), (mod + 1) / 2).x;
}

// Modular addition / subtraction for operands already reduced to [0, mod).
int add(int x, int y) {
    return (x + y >= mod) ? (x + y - mod) : (x + y);
}

int sub(int x, int y) {
    return (x - y < 0) ? (x - y + mod) : (x - y);
}

void Add(int &x, int y) {
    x += y;
    if (x >= mod) x -= mod;
}

void Sub(int &x, int y) {
    x -= y;
    if (x < 0) x += mod;
}

// Array-based polynomial toolkit over Z/pZ. A polynomial is a coefficient array (lowest degree
// first) plus its coefficient count. Routines share the member scratch arrays, so calls must not
// be interleaved, and caller buffers must be large enough for the padded power-of-two transform
// length each routine uses internally.
struct NTT {
    // Shared scratch buffers; Inv[i] = i^{-1} (mod p) for 1 <= i <= M.
    int rev[N], res[N], Inv[N], Res[N], eres[N], Eres[N], rg[N];
    // Length for which rev[] was last built by init().
    int Lst;

    // Linear-time inverse table: Inv[i] = -(p / i) * Inv[p mod i].
    void preinv(int n) {
        Inv[0] = Inv[1] = 1;
        for (int i = 2; i <= n; i++) {
            Inv[i] = 1ll * (mod - mod / i) * Inv[mod % i] % mod;
        }
    }

    NTT() {
        Lst = 0;
        preinv(M);
    }

    // Bit-reversal permutation for length n (a power of two); DFT/IDFT below do not need it.
    void init(int n) {
        if (n == Lst) return;
        rev[0] = 0;
        for (int i = 1; i < n; i++) {
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
        }
        Lst = n;
    }

    // In-place forward NTT of length n (a power of two), decimation in frequency.
    // The spectrum is left in bit-reversed order, which is exactly the order IDFT consumes.
    void DFT(int *f, int n) {
        for (int len = n; len > 1; len >>= 1) {
            int half = len >> 1;
            int root = fhqAKIOI::qpow(G, (mod - 1) / len);
            rg[0] = 1;
            for (int i = 1; i < half; i++) rg[i] = 1ll * rg[i - 1] * root % mod;
            for (int l = 0; l < n; l += len) {
                for (int i = l; i < l + half; i++) {
                    int t = 1ll * sub(f[i], f[i + half]) * rg[i - l] % mod;
                    Add(f[i], f[i + half]);
                    f[i + half] = t;
                }
            }
        }
    }

    // In-place inverse NTT of length n (decimation in time). Takes bit-reversed input and
    // produces natural-order coefficients, scaled by 1/n.
    void IDFT(int *f, int n) {
        for (int len = 2; len <= n; len <<= 1) {
            int half = len >> 1;
            int root = fhqAKIOI::qpow(invG, (mod - 1) / len);
            rg[0] = 1;
            for (int i = 1; i < half; i++) rg[i] = 1ll * rg[i - 1] * root % mod;
            for (int l = 0; l < n; l += len) {
                for (int i = l; i < l + half; i++) {
                    int t = 1ll * f[i + half] * rg[i - l] % mod;
                    f[i + half] = sub(f[i], t);
                    f[i] = add(f[i], t);
                }
            }
        }
        int invN = fhqAKIOI::qpow(n, mod - 2);
        for (int i = 0; i < n; i++) f[i] = 1ll * f[i] * invN % mod;
    }

    // f <- (first n coeffs of f) * (first n coeffs of g) as a cyclic convolution of length
    // m = next power of two >= n, truncated to n terms. Choose n > deg f + deg g for the exact
    // product. Both buffers need m slots. g is left transformed unless flag is set, in which
    // case it is transformed back. On return f[n..m) and g[n..m) are zero.
    void Mul(int *f, int *g, int n, int flag = 0) {
        int m = 1;
        while (m < n) m *= 2;
        for (int i = n; i < m; i++) f[i] = g[i] = 0;
        DFT(f, m);
        DFT(g, m);
        for (int i = 0; i < m; i++) f[i] = 1ll * f[i] * g[i] % mod;
        IDFT(f, m);
        if (flag) IDFT(g, m);
        for (int i = n; i < m; i++) f[i] = g[i] = 0;
    }

    // g <- f^{-1} mod x^n via Newton iteration g <- g * (2 - f * g). Requires f[0] != 0.
    // g needs room for the next power of two >= 2n; uses res as scratch.
    void getinv(int *f, int *g, int n) {
        if (n == 1) {
            g[0] = fhqAKIOI::qpow(f[0], mod - 2);
            return;
        }

        getinv(f, g, (n + 1) / 2);

        for (int i = 0; i < n; i++) res[i] = f[i];

        int m = 1;
        while (m < 2 * n) m *= 2;

        for (int i = n; i < m; i++) res[i] = g[i] = 0;

        DFT(g, m);
        DFT(res, m);

        for (int i = 0; i < m; i++) {
            g[i] = 1ll * sub(2, 1ll * res[i] * g[i] % mod) * g[i] % mod;
            res[i] = 0;
        }

        IDFT(g, m);

        for (int i = n; i < m; i++) g[i] = 0;
    }

    // Division with remainder: f (n coeffs) = g (m coeffs) * h1 + h2. Writes the n-m+1 quotient
    // coefficients to h1 and the m-1 remainder coefficients to h2. Uses the reversal trick
    // rev(h1) = rev(f) / rev(g) mod x^{n-m+1}, so g[m-1] must be invertible.
    // If n < m the quotient is empty and h2 receives f zero-padded to m-1 coefficients.
    void getmod(int *f, int *g, int *h1, int *h2, int n, int m) {
        if (n < m) {
            for (int i = 0; i < m - 1; i++) h2[i] = i < n ? f[i] : 0;
            return;
        }

        const int q = n - m + 1;  // number of quotient coefficients
        int k = 1;

        while (k <= 2 * n) k *= 2;
        for (int i = 0; i < k; i++) eres[i] = Eres[i] = res[i] = 0;

        k = 1;
        while (k <= q) k *= 2;

        // eres = rev(g) truncated to q terms; Eres = its inverse.
        for (int i = 0; i < m; i++) eres[i] = g[i];
        std::reverse(eres, eres + m);
        for (int i = q; i < m; i++) eres[i] = 0;

        getinv(eres, Eres, k);

        k = 1;
        while (k <= 2 * n) k *= 2;

        // res = rev(f) * rev(g)^{-1} mod x^q = rev(quotient).
        for (int i = 0; i < k; i++) res[i] = 0;
        for (int i = 0; i < n; i++) res[i] = f[i];
        std::reverse(res, res + n);
        for (int i = q; i < k; i++) res[i] = Eres[i] = 0;

        Mul(res, Eres, k);

        for (int i = q; i < k; i++) res[i] = 0;
        std::reverse(res, res + q);

        for (int i = 0; i < q; i++) h1[i] = res[i];

        // Remainder = f - g * quotient (only its low m-1 coefficients survive).
        for (int i = 0; i < k; i++) eres[i] = 0;
        for (int i = 0; i < m; i++) eres[i] = g[i];

        Mul(eres, res, k);

        for (int i = 0; i < m - 1; i++) h2[i] = sub(f[i], eres[i]);
        for (int i = 0; i < k; i++) eres[i] = res[i] = Eres[i] = 0;
    }

    // In place f <- f mod g (g has m coeffs, g[m-1] invertible). On return n = m - 1 and
    // f[m-1..old n) is zeroed. Does nothing when n < m (f is already reduced).
    void getmod(int *f, int *g, int &n, int m) {
        if (n < m) return;
        const int q = n - m + 1;  // number of quotient coefficients
        int k = 1;

        while (k <= 2 * n) k *= 2;
        for (int i = 0; i < k; i++) eres[i] = Eres[i] = res[i] = 0;

        k = 1;
        while (k <= q) k *= 2;

        for (int i = 0; i < m; i++) eres[i] = g[i];
        std::reverse(eres, eres + m);
        for (int i = q; i < m; i++) eres[i] = 0;

        getinv(eres, Eres, k);

        k = 1;
        while (k <= 2 * n) k *= 2;

        for (int i = 0; i < k; i++) res[i] = 0;
        for (int i = 0; i < n; i++) res[i] = f[i];
        std::reverse(res, res + n);
        for (int i = q; i < k; i++) res[i] = Eres[i] = eres[i] = 0;

        Mul(res, Eres, k);

        for (int i = q; i < k; i++) res[i] = 0;
        std::reverse(res, res + q);

        for (int i = 0; i < k; i++) eres[i] = 0;
        for (int i = 0; i < m; i++) eres[i] = g[i];

        Mul(eres, res, k);

        for (int i = 0; i < m - 1; i++) Sub(f[i], eres[i]);

        for (int i = m - 1; i < n; i++) f[i] = 0;
        for (int i = 0; i < k; i++) eres[i] = res[i] = Eres[i] = 0;

        n = m - 1;
    }

    // In place f <- f' keeping n slots (the top slot becomes 0).
    void Deriv(int *f, int n) {
        for (int i = 0; i < n - 1; i++) f[i] = 1ll * f[i + 1] * (i + 1) % mod;
        f[n - 1] = 0;
    }

    // In place f <- integral of f with zero constant term; the old f[n-1] is dropped.
    void Integral(int *f, int n) {
        for (int i = n - 1; i > 0; i--) f[i] = 1ll * f[i - 1] * Inv[i] % mod;
        f[0] = 0;
    }

    // g <- ln f mod x^n = integral(f' / f); assumes f[0] == 1.
    void Ln(int *f, int *g, int n) {
        getinv(f, Res, n);
        for (int i = 0; i < n; i++) res[i] = f[i];
        for (int i = n; i < 2 * n; i++) res[i] = 0;
        Deriv(res, n);
        Mul(res, Res, 2 * n);
        Integral(res, n);
        for (int i = 0; i < n; i++) g[i] = res[i];
        for (int i = 0; i < 2 * n; i++) res[i] = 0;
    }

    // g <- exp f mod x^n via Newton iteration g <- g * (1 - ln g + f); assumes f[0] == 0.
    // g needs room for the next power of two >= 2n; its prior contents do not matter.
    void exp(int *f, int *g, int n) {
        if (n == 1) {
            g[0] = 1;
            return;
        }

        int h = (n + 1) / 2;
        exp(f, g, h);
        // Only g[0..h) is meaningful now; clear the rest so it cannot leak into Ln or the product.
        for (int i = h; i < 2 * n; i++) g[i] = 0;
        Ln(g, eres, n);
        for (int i = n; i < 2 * n; i++) eres[i] = 0;
        for (int i = 0; i < n; i++) Sub(eres[i], f[i]);
        for (int i = 0; i < n; i++) Eres[i] = g[i];
        // eres <- (ln g - f) * g; then g <- g - eres.
        Mul(eres, g, n * 2);
        for (int i = n; i < 2 * n; i++) eres[i] = 0;
        for (int i = 0; i < n; i++) g[i] = sub(Eres[i], eres[i]);
        // Strictly below 2n: g[2n] may lie outside the caller's padded buffer.
        for (int i = n; i < 2 * n; i++) g[i] = 0;
        for (int i = 0; i < 2 * n; i++) eres[i] = Eres[i] = 0;
    }

    // f <- f^2 over its first n coefficients; n becomes 2n - 1. f needs the padded length of 2n.
    void sqr(int *f, int &n) {
        for (int i = 0; i < n; i++) res[i] = f[i];
        for (int i = n; i < 2 * n; i++) res[i] = f[i] = 0;
        Mul(f, res, n * 2);
        for (int i = 0; i < 2 * n; i++) res[i] = 0;
        n = n * 2 - 1;
    }

    // f <- f * g^y reduced modulo p, where f has m coeffs, g has k coeffs and p has n coeffs.
    // g is clobbered (repeatedly squared and reduced). Presumably m, k <= n - 1 on entry; verify at call sites.
    void qpow(int *f, int *g, int *p, int m, int k, int n, int y) {
        while (y) {
            if (y & 1) {
                for (int i = k; i < k + m - 1; i++) g[i] = 0;
                for (int i = m; i < k + m - 1; i++) f[i] = 0;
                Mul(f, g, k + m - 1, 1);  // flag = 1 restores g for the squaring below
                m += k - 1;
                getmod(f, p, m, n);
            }

            sqr(g, k);
            getmod(g, p, k, n);
            y /= 2;
        }
    }

    // g <- sqrt(f) mod x^n via Newton iteration g <- (g^2 + f) / (2g).
    // Picks the smaller of the two roots of f[0] as the constant term.
    void sqrt(int *f, int *g, int n) {
        if (n == 1) {
            g[0] = Cipolla(f[0]);
            // NOTE(review): if f[0] is a non-residue Cipolla returns -1 and the result is meaningless.
            g[0] = std::min(g[0], mod - g[0]);
            return;
        }

        int m = (n + 1) / 2;
        sqrt(f, g, m);

        int k = 1;
        while (k <= m) k *= 2;
        k *= 2;
        for (int i = 0; i < 2 * k; i++) Res[i] = eres[i] = Eres[i] = 0;
        for (int i = 0; i < m; i++) Res[i] = g[i];
        int l = m;
        sqr(Res, m);  // Res = g^2, m becomes 2m - 1
        for (int i = n; i < m; i++) Res[i] = 0;
        for (int i = 0; i < n; i++) Add(Res[i], f[i]);
        for (int i = 0; i < l; i++) eres[i] = add(g[i], g[i]);
        getinv(eres, Eres, k);  // Eres = (2g)^{-1}
        Mul(Res, Eres, k * 2);
        for (int i = 0; i < n; i++) g[i] = Res[i];
        for (int i = 0; i < 2 * k; i++) Res[i] = Eres[i] = eres[i] = 0;
    }

    // Segment tree of subproduct polynomials prod (1 - a_i x) used by multipoint evaluation.
    std::vector<int> tr[N];

#define ls (num << 1)
#define rs (num << 1 | 1)
#define mid ((l + r) / 2)

    int A[N], B[N];

    // Full product of two coefficient vectors; empty if either operand is empty.
    std::vector<int> Mul(std::vector<int> f, std::vector<int> g) {
        int n = f.size(), m = g.size();
        if (n == 0 || m == 0) return std::vector<int>();
        int k = 1;
        while (k < n + m) k *= 2;
        for (int i = 0; i < n; i++) A[i] = f[i];
        for (int i = 0; i < m; i++) B[i] = g[i];
        for (int i = n; i < k; i++) A[i] = 0;
        for (int i = m; i < k; i++) B[i] = 0;
        Mul(A, B, k);
        std::vector<int> h(n + m - 1);
        for (int i = 0; i < (int)h.size(); i++) h[i] = A[i];
        return h;
    }

    // tr[num] = prod_{i = l..r} (1 - f[i] x).
    void build(int *f, int l, int r, int num) {
        if (l == r) {
            tr[num].resize(2);
            tr[num][0] = 1;
            tr[num][1] = (mod - f[l]) % mod;
            return;
        }

        build(f, l, mid, ls);
        build(f, mid + 1, r, rs);

        tr[num] = Mul(tr[ls], tr[rs]);
    }

    // Transposed multiplication: h[i] = sum_j f[i + j] * g[j] for 0 <= i < |f|.
    std::vector<int> MulT(std::vector<int> f, std::vector<int> g) {
        int n = f.size(), m = g.size();
        for (int i = 0; i < n; i++) eres[i] = f[i];
        for (int i = 0; i < m; i++) Eres[i] = g[m - i - 1];
        for (int i = n; i < n + m; i++) eres[i] = 0;
        for (int i = m; i < n + m; i++) Eres[i] = 0;
        Mul(eres, Eres, n + m);
        std::vector<int> h(n);
        for (int i = 0; i < n; i++) h[i] = eres[i + m - 1];
        return h;
    }

    int gg[N];

    // Power-series inverse of f modulo x^{|f|}.
    std::vector<int> inv(std::vector<int> f) {
        int n = f.size();
        for (int i = 0; i < n; i++) Res[i] = f[i], gg[i] = 0;
        getinv(Res, gg, n);
        std::vector<int> g(n);
        for (int i = 0; i < n; i++) g[i] = gg[i];
        return g;
    }

    // Pushes the transposed vector down the subproduct tree; leaf l receives f(a_l).
    void GetPoints(int *f, int l, int r, std::vector<int> F, int num) {
        F.resize(r - l + 1);
        if (l == r) {
            f[l] = F[0];
            return;
        }

        std::vector<int> L = MulT(F, tr[rs]), R = MulT(F, tr[ls]);

        GetPoints(f, l, mid, L, ls);
        GetPoints(f, mid + 1, r, R, rs);
    }

    // h[i] = f(g[i]) for 0 <= i < n, where f has n coefficients (Tellegen-style evaluation).
    void GetMultiPoints(int *f, int *g, int *h, int n) {
        build(g, 0, n - 1, 1);
        std::vector<int> coeffs(f, f + n);
        std::vector<int> F = MulT(coeffs, inv(tr[1]));
        GetPoints(h, 0, n - 1, F, 1);
    }
} P;

// std::vector-backed polynomial (coefficients lowest degree first) with NTT helpers.
struct Poly : std::vector<int> {
    template<typename... argT>
    Poly(argT &&... args) : std::vector<int>(std::forward<argT>(args)...) {}

    // In-place forward NTT; size() must be a power of two. Output is bit-reversed.
    void DFT() {
        Poly &f = *this;
        int n = f.size();
        Poly rg(n);
        for (int len = n; len > 1; len >>= 1) {
            int half = len >> 1;
            int root = fhqAKIOI::qpow(G, (mod - 1) / len);
            rg[0] = 1;
            for (int i = 1; i < half; i++) rg[i] = 1ll * rg[i - 1] * root % mod;
            for (int l = 0; l < n; l += len) {
                for (int i = l; i < l + half; i++) {
                    int t = 1ll * sub(f[i], f[i + half]) * rg[i - l] % mod;
                    Add(f[i], f[i + half]);
                    f[i + half] = t;
                }
            }
        }
    }

    // In-place inverse NTT from bit-reversed input; size() must be a power of two.
    void IDFT() {
        Poly &f = *this;
        int n = f.size();
        Poly rg(n);
        for (int len = 2; len <= n; len <<= 1) {
            int half = len >> 1;
            int root = fhqAKIOI::qpow(invG, (mod - 1) / len);
            rg[0] = 1;
            for (int i = 1; i < half; i++) rg[i] = 1ll * rg[i - 1] * root % mod;
            for (int l = 0; l < n; l += len) {
                for (int i = l; i < l + half; i++) {
                    int t = 1ll * f[i + half] * rg[i - l] % mod;
                    f[i + half] = sub(f[i], t);
                    f[i] = add(f[i], t);
                }
            }
        }
        int invN = fhqAKIOI::qpow(n, mod - 2);
        for (int i = 0; i < n; i++) f[i] = 1ll * f[i] * invN % mod;
    }

    // Power-series inverse modulo x^{size()} (Newton iteration); requires front() != 0.
    // Returns an empty polynomial for empty input.
    Poly inv() {
        Poly &F = *this;
        int n = F.size();
        if (n == 0) return Poly();
        Poly g(1), f = F;
        g[0] = fhqAKIOI::qpow(f[0], mod - 2);
        int m = 1;
        while (m < n) m *= 2;
        f.resize(m);
        for (int len = 2; len <= m; len <<= 1) {
            Poly h(len << 1);
            for (int i = 0; i < len; i++) h[i] = f[i];
            g.resize(len << 1);
            g.DFT();
            h.DFT();
            for (int i = 0; i < (len << 1); i++) {
                g[i] = 1ll * sub(2, 1ll * g[i] * h[i] % mod) * g[i] % mod;
            }
            g.IDFT();
            g.resize(len);
        }
        g.resize(n);
        return g;
    }
};

// Full product of two polynomials; empty if either operand is empty.
Poly operator*(Poly f, Poly g) {
    int n = f.size(), m = g.size();
    if (n == 0 || m == 0) return Poly();
    int k = 1;
    while (k < n + m - 1) k *= 2;
    f.resize(k);
    g.resize(k);
    f.DFT();
    g.DFT();
    for (int i = 0; i < k; i++) f[i] = 1ll * f[i] * g[i] % mod;
    f.IDFT();
    f.resize(n + m - 1);
    return f;
}
}  // namespace fhqAKIOI

浙公网安备 33010602011771号