MTT 小记

有的时候,我们想要让多项式乘法结果中的系数,对一些不是那么常规的模数取模,或者对任意质数 \(p\) 取模,这时候我们可以用 MTT 来解决。

MTT 有两种,一种是三模数 NTT,另一种是拆系数 FFT。

  • 三模数 NTT

三模数 NTT 就是,选择三个著名的 NTT 模数,比如 998244353、1004535809、469762049,它们的原根都是 3。然后我们求出模它们意义下的答案,接着 CRT 合并。

具体地,如果我们求出了 \(x \equiv x_1 \pmod A, x \equiv x_2 \pmod B, x \equiv x_3 \pmod C\)

那么有 \(x_1 + k_1A = x_2 + k_2B\)。则可以解得 \(k_1 = \frac{x_2 - x_1}{A} \pmod B\)

于是我们把 \(k_1\) 代回 \(x \equiv x_1 + k_1A\) 中可以得到 \(x \equiv x_4 \pmod{AB}\)

然后我们用类似的方法,有 \(x_4 + k_4AB = x_3 + k_3C\)

解得 \(k_4 = \frac{x_3 - x_4}{AB} \pmod C\),代回来可以得到 \(x \equiv x_5 \pmod{ABC}\)

由于 \(ABC\) 大于真实值,所以 \(x_5 \mod p\) 即为答案。

  • 拆系数 FFT

考虑让 \(P(x) = tA(x) + B(x), Q(x) = tC(x) + D(x)\),其中 \(A(x), C(x)\) 为除 \(t\) 后的商,\(B(x), D(x)\) 为除 \(t\) 后的余数。

\(P(x)Q(x) = (tA(x) + B(x))(tC(x) + D(x))\)

拆开以后大力算一下就行,其中 \(t\) 一般取 \(\sqrt p\) 级别的二的正整数次幂。

好像用个什么共轭优化就可以只需要 4 次变换,比朴素的 7 次要优很多。

  • 代码实现(三模数 NTT)
#include <set>
#include <map>
#include <cmath>
#include <queue>
#include <vector>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#define maxN 100010
const long long g = 3;
const long long mod[3] = {998244353, 1004535809, 469762049};
struct Poly{ long long a[maxN << 2]; long long N; } P, Q;
long long rev[maxN << 2];
long long read ()
{
	long long x = 0, Fu = 1;
	char c = getchar();
	while(c < '0' || c > '9')
	{
		if(c == '-') Fu = -1;
		c = getchar();
	}
	while(c >= '0' && c <= '9')
	{
		x = x * 10 + (c - '0');
		c = getchar();
	}
	return x * Fu;
}
void Swap (long long &x, long long &y) { long long t = x; x = y; y = t; }
long long Abs (long long x) { return (x >= 0) ? x : (-x); }
long long Min (long long x, long long y) { return x < y ? x : y; }
long long Max (long long x, long long y) { return x > y ? x : y; }
long long Pow (long long x, long long y, long long pre)
{
	if(!y) return 1;
	long long res = Pow(x, y >> 1, pre);
	if(!(y & 1)) return res * res % mod[pre];
	else return res * res % mod[pre] * x % mod[pre];
}
void NTT (Poly &A, long long typ, long long Limit, long long pre)
{
	long long G = Pow(g, mod[pre] - 2, pre);
	for(long long i = 0;i < Limit; i++)
		if(i < rev[i]) Swap(A.a[i], A.a[rev[i]]);
	for(long long mid = 1;mid < Limit;mid <<= 1)
	{
		long long Wn = Pow((typ == 1) ? g : G, (mod[pre] - 1) / (mid << 1), pre);
		for(long long j = 0;j < Limit;j += (mid << 1))
		{
			long long w = 1;
			for(long long k = 0;k < mid; k++, w = w * Wn % mod[pre])
			{
				long long x = A.a[j + k], y = A.a[j + mid + k] * w % mod[pre];
				A.a[j + k] = (x + y) % mod[pre]; A.a[j + mid + k] = (x - y + mod[pre]) % mod[pre];
			}
		}
	}
}
Poly NTT_Multi (Poly P, Poly Q, long long pre)
{
	long long N = P.N, M = Q.N; long long Len = N + M;
	long long Limit = 1, l = 0; while(Limit <= N + M) Limit <<= 1, l++;
	for(long long i = 0;i < Limit; i++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
	NTT(P, 1, Limit, pre); NTT(Q, 1, Limit, pre);
	for(long long i = 0;i < Limit; i++)
		P.a[i] = P.a[i] * Q.a[i] % mod[pre];
	NTT(P, -1, Limit, pre);
	long long inv_N = Pow(Limit, mod[pre] - 2, pre);
	P.N = Len;
	for(long long i = 0;i <= P.N; i++)
		P.a[i] = P.a[i] * inv_N % mod[pre];
	return P;
}
int main ()
{
	P.N = read(); Q.N = read(); long long p = read();
	for(long long i = 0;i <= P.N; i++) P.a[i] = read() % p;
	for(long long i = 0;i <= Q.N; i++) Q.a[i] = read() % p;
	Poly G1 = NTT_Multi(P, Q, 0);
	Poly G2 = NTT_Multi(P, Q, 1);
	Poly G3 = NTT_Multi(P, Q, 2);
	long long A = mod[0], B = mod[1], C = mod[2];
	for(long long i = 0;i <= P.N + Q.N; i++)
	{
		long long x1 = G1.a[i], x2 = G2.a[i], x3 = G3.a[i];
		long long k1 = (x2 - x1 + B) % B * Pow(A, B - 2, 1) % B;
		long long x4 = (x1 + k1 * A);
		long long k4 = (x3 - (x4 % C) + C) % C * Pow(A * B % C, C - 2, 2) % C;
		printf("%lld ", (x4 + k4 * A % p * B) % p);
	}
	return 0;
}
posted @ 2024-07-05 20:00  abcdeffa  阅读(56)  评论(0)    收藏  举报