多项式相关

多项式乘法

FFT 就不讲了。

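NTT 即将 FFT 中的单位根 $\omega_n^k$ 替换为 $g^{\frac{k(p-1)}{n}} \bmod p$ 即可,其中 $g$ 为模数 $p$ 的原根。这要求 $n \mid p-1$:例如 $p = 998244353 = 119 \times 2^{23} + 1$、$g = 3$ 时,$n$ 可取不超过 $2^{23}$ 的 2 的幂;逆变换改用 $g^{-1}$,并在最后乘以 $n^{-1}$。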

板子:

  // Polynomial template over Z_p, p = 998244353, built on the number-theoretic
  // transform (NTT).
  //
  // Conventions:
  //   * A polynomial with "n terms" is stored as f[0..n-1], f[i] being the
  //     coefficient of x^i; coefficients are kept in [0, mod).
  //   * The raw-array routines of struct NTT share the scratch buffers of the
  //     single global instance P, so they are not reentrant, and every
  //     (power-of-two padded) length they work with must stay below N.
  //   * Caller-supplied arrays must be large enough for the padded length a
  //     routine writes (usually the next power of two >= 2n).
  //   * Standard-library names are spelled with std:: so the template does not
  //     depend on a `using namespace std;` in the including file.
  namespace fhqAKIOI {
      // Scratch capacity: the largest transform length the buffers can hold.
      const int M = 8e5, N = M + 5;
      // NTT prime 998244353 = 119 * 2^23 + 1, its primitive root 3, and 3^{-1}.
      const int mod = 998244353, G = 3, invG = 332748118;
      std::mt19937_64 rnd(std::chrono::system_clock::now().time_since_epoch().count());

      // Random integer in [l, r] (the small modulo bias is irrelevant here).
      int gen(int l, int r) {
          return rnd() % (r - l + 1) + l;
      }

      // Quadratic non-residue chosen by Cipolla(); Ply::operator* uses it as s^2.
      int w;

      // Element x + y*s of the extension field F_p[s] with s^2 = w.
      struct Ply {
          int x, y;
          Ply(int x = 0, int y = 0) : x(x), y(y) {}
          // (x1 + y1 s)(x2 + y2 s) = (x1 x2 + y1 y2 w) + (x1 y2 + y1 x2) s.
          Ply operator * (const Ply &a) const {
              Ply z;
              z.x = (1ll * a.x * x % mod + 1ll * a.y * y % mod * w % mod) % mod;
              z.y = (1ll * a.x * y % mod + 1ll * a.y * x % mod) % mod;
              return z;
          }
      };

      // x^y in F_p[s] by binary exponentiation.
      Ply qpow(Ply x, int y) {
          Ply a(1, 0);
          while (y) {
              if (y & 1) a = a * x;
              x = x * x;
              y /= 2;
          }
          return a;
      }

      // x^y mod `mod` by binary exponentiation.
      int qpow(int x, int y) {
          int a = 1;
          while (y) {
              if (y & 1) a = 1ll * a * x % mod;
              x = 1ll * x * x % mod;
              y /= 2;
          }
          return a;
      }

      // Cipolla's algorithm: returns some r with r^2 == n (mod p), 0 when
      // n == 0 (mod p), or -1 when n is a quadratic non-residue (Euler's
      // criterion). Which of the two roots is returned is random.
      int Cipolla(int n) {
          n %= mod;
          if (!n) return 0;
          if (qpow(n, (mod - 1) / 2) == mod - 1) return -1;

          int x;
          while (true) {
              x = gen(0, mod - 1);
              // 1ll: x * x overflows int as soon as x exceeds 46340.
              w = (1ll * x * x % mod - n + mod) % mod;
              if (qpow(w, (mod - 1) / 2) == mod - 1) break;
          }
          // (x + s)^((p+1)/2) squares to x^2 - w = n; its s-part is 0.
          return qpow(Ply(x, 1), (mod + 1) / 2).x;
      }

      // Modular add/sub for operands in [0, mod), returning the result.
      int add(int x, int y) {
          return (x + y >= mod) ? (x + y - mod) : (x + y);
      }

      int sub(int x, int y) {
          return (x - y < 0) ? (x - y + mod) : (x - y);
      }

      // In-place variants of add/sub.
      void Add(int &x, int y) {
          x += y;
          if (x >= mod) x -= mod;
      }

      void Sub(int &x, int y) {
          x -= y;
          if (x < 0) x += mod;
      }

      // Raw-array polynomial toolkit. All routines share the scratch buffers
      // below; use the global instance P.
      struct NTT {
          int rev[N], res[N], Inv[N], Res[N], eres[N], Eres[N], rg[N];
          int Lst;

          // Inv[i] = i^{-1} mod p for 1 <= i <= n (linear-time recurrence).
          void preinv(int n) {
              Inv[0] = Inv[1] = 1;
              for (int i = 2; i <= n; i++) {
                  Inv[i] = 1ll * (mod - mod / i) * Inv[mod % i] % mod;
              }
          }

          NTT() {
              Lst = 0;
              preinv(M);
          }

          // Builds the bit-reversal table rev[] for length n (cached by n).
          // DFT/IDFT below do not use it.
          void init(int n) {
              if (n == Lst) return;
              rev[0] = 0;
              for (int i = 1; i < n; i++) {
                  rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
              }
              Lst = n;
          }

          // Forward transform of length n (power of two), decimation in
          // frequency: natural-order input, bit-reversed-order output.
          void DFT(int *f, int n) {
              for (int len = n; len > 1; len >>= 1) {
                  int Len = len >> 1;
                  int g = fhqAKIOI::qpow(G, (mod - 1) / len);
                  rg[0] = 1;
                  for (int i = 1; i < Len; i++) rg[i] = 1ll * rg[i - 1] * g % mod;
                  for (int l = 0; l < n; l += len) {
                      for (int i = l; i < l + Len; i++) {
                          int lst = 1ll * sub(f[i], f[i + Len]) * rg[i - l] % mod;
                          Add(f[i], f[i + Len]);
                          f[i + Len] = lst;
                      }
                  }
              }
          }

          // Inverse of DFT: bit-reversed input, natural-order output, scaled
          // by n^{-1}.
          void IDFT(int *f, int n) {
              for (int len = 2; len <= n; len <<= 1) {
                  int Len = len >> 1;
                  int g = fhqAKIOI::qpow(invG, (mod - 1) / len);
                  rg[0] = 1;
                  for (int i = 1; i < Len; i++) rg[i] = 1ll * rg[i - 1] * g % mod;
                  for (int l = 0; l < n; l += len) {
                      for (int i = l; i < l + Len; i++) {
                          int lst = 1ll * f[i + Len] * rg[i - l] % mod;
                          f[i + Len] = sub(f[i], lst);
                          f[i] = add(f[i], lst);
                      }
                  }
              }
              int invN = fhqAKIOI::qpow(n, mod - 2);
              for (int i = 0; i < n; i++) {
                  f[i] = 1ll * f[i] * invN % mod;
              }
          }

          // f <- (f * g) as a cyclic convolution of length m = next power of
          // two >= n, truncated to n terms; pick n > deg f + deg g for an exact
          // product. Entries [n, m) of f and g are zeroed before and after.
          // g is left in transformed (garbage) state unless flag != 0.
          void Mul(int *f, int *g, int n, int flag = 0) {
              int m = 1;
              while (m < n) m *= 2;
              for (int i = n; i < m; i++) f[i] = g[i] = 0;
              DFT(f, m);
              DFT(g, m);
              for (int i = 0; i < m; i++) {
                  f[i] = 1ll * f[i] * g[i] % mod;
              }
              IDFT(f, m);
              if (flag) IDFT(g, m);
              for (int i = n; i < m; i++) f[i] = g[i] = 0;
          }

          // g <- f^{-1} mod x^n by Newton iteration g' = g (2 - f g).
          // Requires f[0] != 0. Writes g up to the next power of two >= 2n;
          // uses res[] as scratch.
          void getinv(int *f, int *g, int n) {
              if (n <= 0) return;  // nothing to compute; avoids endless recursion
              if (n == 1) {
                  g[0] = fhqAKIOI::qpow(f[0], mod - 2);
                  return;
              }

              getinv(f, g, (n + 1) / 2);

              for (int i = 0; i < n; i++) res[i] = f[i];
              int m = 1;
              while (m < 2 * n) m *= 2;
              for (int i = n; i < m; i++) res[i] = g[i] = 0;

              DFT(g, m);
              DFT(res, m);
              for (int i = 0; i < m; i++) {
                  g[i] = 1ll * sub(2, 1ll * res[i] * g[i] % mod) % mod * g[i] % mod;
                  res[i] = 0;
              }
              IDFT(g, m);

              for (int i = n; i < m; i++) g[i] = 0;
          }

          // Polynomial division: f (n terms) = g (m terms) * h1 + h2.
          // h1 receives the n - m + 1 quotient terms, h2 the m - 1 remainder
          // terms. Assumes g[m - 1] != 0.
          void getmod(int *f, int *g, int *h1, int *h2, int n, int m) {
              if (n < m) {
                  // deg f < deg g: the quotient has no terms and the remainder
                  // is f itself, zero-padded to m - 1 terms.
                  for (int i = 0; i < m - 1; i++) h2[i] = i < n ? f[i] : 0;
                  return;
              }

              int k = 1;
              while (k <= 2 * n) k *= 2;
              for (int i = 0; i < k; i++) eres[i] = Eres[i] = res[i] = 0;

              // Eres = (reversed g)^{-1} mod x^{n-m+1}.
              k = 1;
              while (k <= n - m + 1) k *= 2;
              for (int i = 0; i < m; i++) eres[i] = g[i];
              std::reverse(eres, eres + m);
              for (int i = n - m + 1; i < m; i++) eres[i] = 0;
              getinv(eres, Eres, k);

              // Reversed quotient = (reversed f) * Eres mod x^{n-m+1}.
              k = 1;
              while (k <= 2 * n) k *= 2;
              for (int i = 0; i < k; i++) res[i] = 0;
              for (int i = 0; i < n; i++) res[i] = f[i];
              std::reverse(res, res + n);
              for (int i = n - m + 1; i < k; i++) res[i] = Eres[i] = 0;
              Mul(res, Eres, k);
              for (int i = n - m + 1; i < k; i++) res[i] = 0;
              std::reverse(res, res + n - m + 1);

              for (int i = 0; i < n - m + 1; i++) {
                  h1[i] = res[i];
              }

              // Remainder = f - g * quotient (low m - 1 terms).
              for (int i = 0; i < k; i++) eres[i] = 0;
              for (int i = 0; i < m; i++) eres[i] = g[i];
              Mul(eres, res, k);
              for (int i = 0; i < m - 1; i++) {
                  h2[i] = sub(f[i], eres[i]);
              }
              for (int i = 0; i < k; i++) eres[i] = res[i] = Eres[i] = 0;
          }

          // In place: f (n terms) <- f mod g (g has m terms). When n >= m,
          // f[m-1..n) is zeroed and n becomes m - 1; otherwise nothing changes.
          void getmod(int *f, int *g, int &n, int m) {
              if (n < m) return;

              int k = 1;
              while (k <= 2 * n) k *= 2;
              for (int i = 0; i < k; i++) eres[i] = Eres[i] = res[i] = 0;

              k = 1;
              while (k <= n - m + 1) k *= 2;
              for (int i = 0; i < m; i++) eres[i] = g[i];
              std::reverse(eres, eres + m);
              for (int i = n - m + 1; i < m; i++) eres[i] = 0;
              getinv(eres, Eres, k);

              k = 1;
              while (k <= 2 * n) k *= 2;
              for (int i = 0; i < k; i++) res[i] = 0;
              for (int i = 0; i < n; i++) res[i] = f[i];
              std::reverse(res, res + n);
              for (int i = n - m + 1; i < k; i++) res[i] = Eres[i] = eres[i] = 0;
              Mul(res, Eres, k);
              for (int i = n - m + 1; i < k; i++) res[i] = 0;
              std::reverse(res, res + n - m + 1);

              for (int i = 0; i < k; i++) eres[i] = 0;
              for (int i = 0; i < m; i++) eres[i] = g[i];
              Mul(eres, res, k);
              for (int i = 0; i < m - 1; i++) {
                  Sub(f[i], eres[i]);
              }

              for (int i = m - 1; i < n; i++) f[i] = 0;
              for (int i = 0; i < k; i++) eres[i] = res[i] = Eres[i] = 0;
              n = m - 1;
          }

          // In place: f (n terms) <- f'; the top coefficient becomes 0.
          void Deriv(int *f, int n) {
              if (n <= 0) return;  // would otherwise write f[-1]
              for (int i = 0; i < n - 1; i++) {
                  f[i] = 1ll * f[i + 1] * (i + 1) % mod;
              }
              f[n - 1] = 0;
          }

          // In place: f (n terms) <- integral of f with constant term 0,
          // truncated to n terms. Requires n - 1 <= M (Inv table size).
          void Integral(int *f, int n) {
              for (int i = n - 1; i > 0; i--) {
                  f[i] = 1ll * f[i - 1] * Inv[i] % mod;
              }
              f[0] = 0;
          }

          // g <- integral(f' / f) mod x^n, which is ln f when f[0] == 1.
          // Uses res[] and Res[] as scratch.
          void Ln(int *f, int *g, int n) {
              getinv(f, Res, n);
              for (int i = 0; i < n; i++) res[i] = f[i];
              for (int i = n; i < 2 * n; i++) res[i] = 0;
              Deriv(res, n);
              Mul(res, Res, 2 * n);
              Integral(res, n);
              for (int i = 0; i < n; i++) g[i] = res[i];
              for (int i = 0; i < 2 * n; i++) res[i] = 0;
          }

          // g <- exp(f) mod x^n by Newton iteration g' = g (1 - ln g + f).
          // Requires f[0] == 0. Writes g[0..2n]; uses eres/Eres (and Ln's
          // scratch).
          void exp(int *f, int *g, int n) {
              if (n <= 1) {
                  if (n == 1) g[0] = 1;
                  return;
              }

              int half = (n + 1) / 2;
              exp(f, g, half);
              // Only g[0..half) is meaningful; clear the rest so stale data in
              // the caller's buffer cannot leak into the convolution below.
              for (int i = half; i < 2 * n; i++) g[i] = 0;

              Ln(g, eres, n);
              // Ln writes eres[0..n) only; the upper half feeds Mul, so clear
              // anything left there by earlier routines.
              for (int i = n; i < 2 * n; i++) eres[i] = 0;
              for (int i = 0; i < n; i++) {
                  Sub(eres[i], f[i]);
              }
              for (int i = 0; i < n; i++) Eres[i] = g[i];
              Mul(eres, g, n * 2);
              for (int i = n; i < 2 * n; i++) eres[i] = 0;
              for (int i = 0; i < n; i++) {
                  g[i] = sub(Eres[i], eres[i]);
              }
              for (int i = n; i <= 2 * n; i++) g[i] = 0;
              for (int i = 0; i < 2 * n; i++) eres[i] = Eres[i] = 0;
          }

          // In place: f (n terms) <- f^2; n becomes 2n - 1. Uses res[].
          void sqr(int *f, int &n) {
              for (int i = 0; i < n; i++) res[i] = f[i];
              for (int i = n; i < 2 * n; i++) res[i] = 0;
              Mul(f, res, n * 2);
              for (int i = 0; i < 2 * n; i++) res[i] = 0;
              n = n * 2 - 1;
          }

          // f <- f * g^y mod p, where f has m terms, g has k terms and p has n
          // terms. g is overwritten with repeated squares mod p. The final
          // length of f is not reported back (m is a by-value parameter).
          void qpow(int *f, int *g, int *p, int m, int k, int n, int y) {
              while (y) {
                  if (y & 1) {
                      for (int i = k; i < k + m - 1; i++) g[i] = 0;
                      for (int i = m; i < k + m - 1; i++) f[i] = 0;
                      Mul(f, g, k + m - 1, 1);  // flag 1 keeps g intact
                      m += k - 1;
                      getmod(f, p, m, n);
                  }
                  sqr(g, k);
                  getmod(g, p, k, n);
                  y /= 2;
              }
          }

          // g <- sqrt(f) mod x^n by Newton iteration g' = (g^2 + f) / (2g).
          // g[0] is the smaller of the two roots of f[0]. Requires f[0] to be
          // a nonzero quadratic residue: for a non-residue Cipolla returns -1
          // and g[0] becomes -1, so the output is meaningless.
          void sqrt(int *f, int *g, int n) {
              if (n <= 0) return;  // nothing to compute; avoids endless recursion
              if (n == 1) {
                  g[0] = Cipolla(f[0]);
                  g[0] = std::min(g[0], mod - g[0]);
                  return;
              }

              int m = (n + 1) / 2;
              sqrt(f, g, m);

              int k = 1;
              while (k <= m) k *= 2;
              k *= 2;
              for (int i = 0; i < 2 * k; i++) {
                  Res[i] = eres[i] = Eres[i] = 0;
              }
              for (int i = 0; i < m; i++) Res[i] = g[i];
              int l = m;
              sqr(Res, m);  // Res = g^2, m becomes 2m - 1
              for (int i = n; i < m; i++) Res[i] = 0;
              for (int i = 0; i < n; i++) Add(Res[i], f[i]);
              for (int i = 0; i < l; i++) eres[i] = add(g[i], g[i]);
              getinv(eres, Eres, k);  // Eres = (2g)^{-1}
              Mul(Res, Eres, k * 2);
              for (int i = 0; i < n; i++) g[i] = Res[i];
              for (int i = 0; i < 2 * k; i++) Res[i] = Eres[i] = eres[i] = 0;
          }

          // Segment tree of products for multipoint evaluation.
          std::vector<int> tr[N];

          // Segment-tree child/midpoint helpers. They stay visible after this
          // struct (macros ignore scope), so they are fully parenthesized.
          #define ls (num << 1)
          #define rs (num << 1 | 1)
          #define mid ((l + r) / 2)

          int A[N], B[N];

          // Exact product of two coefficient vectors (empty if either is empty).
          std::vector<int> Mul(const std::vector<int> &f, const std::vector<int> &g) {
              if (f.empty() || g.empty()) return std::vector<int>();
              int n = f.size(), m = g.size();
              int k = 1;
              while (k < n + m) k *= 2;
              for (int i = 0; i < k; i++) {
                  A[i] = i < n ? f[i] : 0;
                  B[i] = i < m ? g[i] : 0;
              }
              Mul(A, B, k);
              return std::vector<int>(A, A + n + m - 1);
          }

          // tr[num] <- product of (1 - f[i] x) for i in [l, r].
          void build(int *f, int l, int r, int num) {
              if (l == r) {
                  tr[num].resize(2);
                  tr[num][0] = 1;
                  tr[num][1] = (mod - f[l]) % mod;
                  return;
              }
              build(f, l, mid, ls);
              build(f, mid + 1, r, rs);
              tr[num] = Mul(tr[ls], tr[rs]);
          }

          // Transposed multiplication: h[i] = sum_j f[i + j] * g[j], i < |f|.
          // Assumes g is non-empty. Uses eres/Eres and leaves them cleared.
          std::vector<int> MulT(const std::vector<int> &f, const std::vector<int> &g) {
              int n = f.size(), m = g.size();
              for (int i = 0; i < n; i++) eres[i] = f[i];
              for (int i = 0; i < m; i++) Eres[i] = g[m - i - 1];
              for (int i = n; i < n + m; i++) eres[i] = 0;
              for (int i = m; i < n + m; i++) Eres[i] = 0;
              Mul(eres, Eres, n + m);
              std::vector<int> h(n);
              for (int i = 0; i < n; i++) h[i] = eres[i + m - 1];
              for (int i = 0; i < n + m; i++) eres[i] = Eres[i] = 0;
              return h;
          }

          int gg[N];

          // f^{-1} mod x^{|f|} as a vector (empty for empty f).
          std::vector<int> inv(const std::vector<int> &f) {
              int n = f.size();
              if (n == 0) return std::vector<int>();
              for (int i = 0; i < n; i++) Res[i] = f[i], gg[i] = 0;
              getinv(Res, gg, n);
              return std::vector<int>(gg, gg + n);
          }

          // Pushes the transposed values down the product tree; the leaf for
          // point l receives f[l].
          void GetPoints(int *f, int l, int r, std::vector<int> F, int num) {
              F.resize(r - l + 1);
              if (l == r) {
                  f[l] = F[0];
                  return;
              }
              std::vector<int> L = MulT(F, tr[rs]), R = MulT(F, tr[ls]);
              GetPoints(f, l, mid, std::move(L), ls);
              GetPoints(f, mid + 1, r, std::move(R), rs);
          }

          // h[i] <- f(g[i]) for i < n, where f has n coefficients (pad the
          // shorter of coefficients/points to a common n). Uses the
          // transposition principle.
          void GetMultiPoints(int *f, int *g, int *h, int n) {
              if (n <= 0) return;
              build(g, 0, n - 1, 1);
              std::vector<int> coef(f, f + n);
              std::vector<int> F = MulT(coef, inv(tr[1]));
              GetPoints(h, 0, n - 1, F, 1);
          }
      } P;

      // std::vector-backed polynomial with self-contained transforms.
      // Note: Poly p{1, 2} forwards (1, 2) to vector(count, value); build from
      // a std::vector to list coefficients.
      struct Poly : std::vector<int> {
          template<typename ... argT>
          Poly(argT &&... args) : std::vector<int>(std::forward<argT>(args)...) {}

          // Same transform as NTT::DFT; size() must be a power of two.
          void DFT() {
              Poly &f = *this;
              int n = f.size();
              Poly rg(n);
              for (int len = n; len > 1; len >>= 1) {
                  int Len = len >> 1;
                  int g = fhqAKIOI::qpow(G, (mod - 1) / len);
                  rg[0] = 1;
                  for (int i = 1; i < Len; i++) rg[i] = 1ll * rg[i - 1] * g % mod;
                  for (int l = 0; l < n; l += len) {
                      for (int i = l; i < l + Len; i++) {
                          int lst = 1ll * sub(f[i], f[i + Len]) * rg[i - l] % mod;
                          Add(f[i], f[i + Len]);
                          f[i + Len] = lst;
                      }
                  }
              }
          }

          // Same transform as NTT::IDFT (includes the 1/n scaling).
          void IDFT() {
              Poly &f = *this;
              int n = f.size();
              Poly rg(n);
              for (int len = 2; len <= n; len <<= 1) {
                  int Len = len >> 1;
                  int g = fhqAKIOI::qpow(invG, (mod - 1) / len);
                  rg[0] = 1;
                  for (int i = 1; i < Len; i++) rg[i] = 1ll * rg[i - 1] * g % mod;
                  for (int l = 0; l < n; l += len) {
                      for (int i = l; i < l + Len; i++) {
                          int lst = 1ll * f[i + Len] * rg[i - l] % mod;
                          f[i + Len] = sub(f[i], lst);
                          f[i] = add(f[i], lst);
                      }
                  }
              }
              int invN = fhqAKIOI::qpow(n, mod - 2);
              for (int i = 0; i < n; i++) {
                  f[i] = 1ll * f[i] * invN % mod;
              }
          }

          // Inverse mod x^{size()} by Newton iteration; requires front() != 0.
          Poly inv() {
              Poly &F = *this;
              if (F.empty()) return F;
              int n = F.size();
              Poly g(1), f = F;
              g[0] = fhqAKIOI::qpow(f[0], mod - 2);
              int m = 1;
              while (m < n) m *= 2;
              f.resize(m);
              for (int len = 2; len <= m; len <<= 1) {
                  Poly h(len << 1);
                  for (int i = 0; i < len; i++) h[i] = f[i];
                  g.resize(len << 1);
                  g.DFT();
                  h.DFT();
                  for (int i = 0; i < len << 1; i++) {
                      g[i] = 1ll * sub(2, 1ll * g[i] * h[i] % mod) * g[i] % mod;
                  }
                  g.IDFT();
                  g.resize(len);
              }
              g.resize(n);
              return g;
          }
      };

      // Exact product; the result has |f| + |g| - 1 terms (empty if either
      // operand is empty).
      Poly operator * (Poly f, Poly g) {
          if (f.empty() || g.empty()) {
              f.clear();
              return f;
          }
          int n = f.size(), m = g.size();
          int k = 1;
          while (k < n + m - 1) k *= 2;
          f.resize(k);
          g.resize(k);
          f.DFT();
          g.DFT();
          for (int i = 0; i < k; i++) f[i] = 1ll * f[i] * g[i] % mod;
          f.IDFT();
          f.resize(n + m - 1);
          return f;
      }
  }
View Code

 

posted @ 2024-01-17 20:56  ORzyzRO  阅读(91)  评论(2)    收藏  举报