Algorithm: Polynomial Multiplication -- Fast Fourier Transform / Number-Theoretic Transform (English version)

Intro:

This blog will start with plain multiplication, go through Divide-and-conquer multiplication, and reach FFT and NTT.

The aim is to enable the reader (and myself) to fully understand the idea.

Template problem: Luogu P3803 【模板】多项式乘法(FFT) ("Template: Polynomial Multiplication (FFT)")


Plain multiplication

Assumption: Two polynomials are \(A(x)=\sum_{i=0}^{n}a_ix^i,B(x)=\sum_{i=0}^{m}b_ix^i\)

Prerequisite knowledge:

Knowledge of junior high school mathematics

The simplest method is to multiply term by term and then combine like terms, written as the formula:

If \(C(x)=A(x)B(x)\), then \(C(x)=\sum_{i=0}^{n+m}c_ix^i\), where \(c_i=\sum_{j=0}^ia_jb_{i-j}\).

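This gives the plain multiplication algorithm; see the code below. The \(b\) array is not stored: each \(b_i\) is multiplied into the answer as soon as it is read. The code also uses some fast-I/O and macro tricks that are not essential.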

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Characters that are neither digits nor '-' are skipped; if the input ends
// before a number starts, the program terminates with exit(0).
int read(){
	int c(getchar()); // int, not char: EOF must stay distinct from byte 0xFF
	while(!isdigit(c)&&c!='-'){
		if(c==EOF)exit(0);
		c=getchar();
	}
	bool const neg(c=='-');
	int u(neg?0:c&15);
	while(isdigit(c=getchar()))u=u*10+(c&15);
	return neg?-u:u;
}
// Writes a to stdout in decimal (with a leading '-' if negative), no separator.
void wr(int a){
	// Take the magnitude in unsigned arithmetic: -a overflows for the minimum
	// value of the type, while 0ULL-u is well defined.
	unsigned long long u(a);
	if(a<0)putchar('-'),u=0ULL-u;
	char buf[24];
	int len(0);
	do buf[len++]=char('0'+u%10);while(u/=10);
	while(len)putchar(buf[--len]);
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
// Loop helpers. No `register`: that storage class was removed in C++17, so
// loops declared with it are ill-formed under -std=c++17.
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
#define N (2000010)
int n,m,a[N],b,c[N];
// Plain O(nm) multiplication. n and m are the degrees of A and B. B's
// coefficients are streamed one at a time through b, and b_i*a_j is added
// into c_{i+j}. Assumes every c_k fits in int.
signed main(){
	Rd(n),Rd(m);
	Frn1(i,0,n)Rd(a[i]);
	Frn1(i,0,m){Rd(b);Frn1(j,0,n)c[i+j]+=b*a[j];}
	Frn1(i,0,n+m)wr(c[i]),Ps;
	exit(0);
}

Time complexity: \(O(nm)\) (if \(m=O(n)\), then \(O(n^2)\))

Memory complexity: \(O(n+m)\)

Results:

As expected, it is too slow for the largest test cases, so we need to optimize it.


Divide-and-conquer multiplication (Fake)

P.S. This part describes the divide-and-conquer idea behind FFT, which is still different from the actual FFT. You can skip it if you have already mastered divide-and-conquer.

Let \(n\) be the smallest power of \(2\) that is strictly greater than the degrees of both \(A(x)\) and \(B(x)\), and write \(A(x)=\sum_{i=0}^{n-1}a_ix^i,B(x)=\sum_{i=0}^{n-1}b_ix^i\), where the nonexistent coefficients are set to \(0\).

Prerequisite knowledge:

  The idea of Divide-and-conquer

Now consider how to optimize multiplication.

Try to separate two polynomials according to the parity of the index of \(x\)

\(A(x)=A^{[0]}(x^2)+xA^{[1]}(x^2),B(x)=B^{[0]}(x^2)+xB^{[1]}(x^2)\),

where \(A^{[0]}(x)=\sum_{i=0}^{n/2-1}a_{2i}x^i,A^{[1]}(x)=\sum_{i=0}^{n/2-1}a_{2i+1}x^i\), and \(B^{[0]}(x)\) and \(B^{[1]}(x)\) are similar.

Therefore, the two polynomials are split into four polynomials, each with degree \(<n/2\).

We let \(A=A(x),A^{[0]}=A^{[0]}(x^2),A^{[1]}=A^{[1]}(x^2)\), and similar for \(B\) and others,

then \(AB=(A^{[0]}+xA^{[1]})(B^{[0]}+xB^{[1]})=A^{[0]}B^{[0]}+x(A^{[1]}B^{[0]}+A^{[0]}B^{[1]})+x^2A^{[1]}B^{[1]}\).

A Divide-and-conquer algorithm can be found here: split two polynomials in half, then recursively do \(4\) polynomial multiplications, and finally combine them together (polynomial addition is \(O(n)\) anyway)

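P.S. Since \(A^{[0]}=A^{[0]}(x^2)\) and \(A^{[1]}=A^{[1]}(x^2)\) are evaluated at \(x^2\), the partial products interleave when combined. \(A^{[0]}B^{[0]}\) and \(x^2A^{[1]}B^{[1]}\) fill the even coefficients of the result, and the terms multiplied by \(x\) fill the odd coefficients. Here is the code. (In the code, the \(n\) above is stored in the variable s, and vector is used to save memory.)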

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Characters that are neither digits nor '-' are skipped; if the input ends
// before a number starts, the program terminates with exit(0).
int read(){
	int c(getchar()); // int, not char: EOF must stay distinct from byte 0xFF
	while(!isdigit(c)&&c!='-'){
		if(c==EOF)exit(0);
		c=getchar();
	}
	bool const neg(c=='-');
	int u(neg?0:c&15);
	while(isdigit(c=getchar()))u=u*10+(c&15);
	return neg?-u:u;
}
// Writes a to stdout in decimal (with a leading '-' if negative), no separator.
void wr(int a){
	// Take the magnitude in unsigned arithmetic: -a overflows for the minimum
	// value of the type, while 0ULL-u is well defined.
	unsigned long long u(a);
	if(a<0)putchar('-'),u=0ULL-u;
	char buf[24];
	int len(0);
	do buf[len++]=char('0'+u%10);while(u/=10);
	while(len)putchar(buf[--len]);
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
// Loop helpers. No `register`: that storage class was removed in C++17, so
// loops declared with it are ill-formed under -std=c++17.
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
typedef vector<int> Vct;
int n,m,s; 
Vct a,b,c;
// c[i]=a[i]+b[i] for every index of c (a and b must be at least as long as c).
void add(Vct&a,Vct&b,Vct&c){Frn0(i,0,c.size())c[i]=a[i]+b[i];}
void mlt(Vct&a,Vct&b,Vct&c,int n);
// Reads A (degree n) and B (degree m), multiplies them with the "fake"
// divide-and-conquer mlt, and prints the n+m+1 product coefficients.
signed main(){
	Rd(n),Rd(m);
	// s = smallest power of two strictly greater than max(n,m). Computed with
	// integers: log2(0) is -inf, and converting -inf to int is undefined.
	s=1;
	while(s<=max(n,m))s<<=1;
	a.resize(s),b.resize(s),c.resize(s<<1);
	Frn1(i,0,n)Rd(a[i]);
	Frn1(i,0,m)Rd(b[i]);
	mlt(a,b,c,s);
	Frn1(i,0,n+m)wr(c[i]),Ps;
	exit(0);
}
// "Fake" divide-and-conquer product of a and b. Both hold n coefficients,
// with n a power of two. Splitting by parity (halves evaluated at x^2):
//   AB = A0B0 + x(A0B1 + A1B0) + x^2 A1B1,
// which needs 4 recursive products: T(n)=4T(n/2)+O(n)=O(n^2).
// Writes c[0..2n-2], so c must have at least 2n-1 elements.
void mlt(std::vector<int>&a,std::vector<int>&b,std::vector<int>&c,int n){
	// Base case before any allocation: leaves are by far the most numerous calls.
	if(n<=1){
		if(n==1)c[0]=a[0]*b[0];
		return;
	}
	int const n2(n>>1);
	std::vector<int> a0(n2),a1(n2),b0(n2),b1(n2),ab0(n),ab1(n);
	for(int i=0;i<n2;++i)a0[i]=a[i<<1],a1[i]=a[i<<1|1],b0[i]=b[i<<1],b1[i]=b[i<<1|1];
	// Even coefficients: A0B0 plus x^2*A1B1 (shifted by one slot).
	mlt(a0,b0,ab0,n2),mlt(a1,b1,ab1,n2);
	for(int i=0;i<n;++i)c[i<<1]=ab0[i]+(i?ab1[i-1]:0);
	// Odd coefficients: A0B1 + A1B0. Each recursive call rewrites indices
	// 0..n-2 of its output, so reusing ab0/ab1 as scratch is safe.
	mlt(a0,b1,ab0,n2),mlt(a1,b0,ab1,n2);
	for(int i=0;i+1<n;++i)c[i<<1|1]=ab0[i]+ab1[i];
}

Results:

even worse

Why's that? Because the Time complexity is still \(O(n^2)\).

\(\textit{Proof. } T(n)=4T(n/2)+f(n)\), in which \(f(n)=O(n)\) the time complexity of polynomial addition.

Using the Master Theorem with \(a=4,b=2,\log_ba=\log_2 4=2>1\), we have \(T(n)=O(n^{\log_ba})=O(n^2)\).

So, let's continue optimizing


Divide-and-conquer multiplication (Real)

Let's consider how to optimize the "fake" one.

An intro question: Try to find an algorithm to multiply linear expressions \(ax+b\) and \(cx+d\) with only \(3\) multiplication steps.

Let's expand the multiplication: \((ax+b)(cx+d)=acx^2+(ad+bc)x+bd\), there seems to be \(4\) multiplication steps used.

Hence, if we can only use \(3\) multiplication steps, then \(ad+bc\) should cost only one.

Let's add all coefficients together: \(ac+ad+bc+bd=(a+b)(c+d)\),

and here is the answer! Use \(3\) multiplication steps to calculate \(ac,bd,(a+b)(c+d)\) respectively, and the \(x\) coefficient is just \(ad+bc=(a+b)(c+d)-ac-bd\)

Let's go back to the original question

As \(AB=(A^{[0]}+xA^{[1]})(B^{[0]}+xB^{[1]})=A^{[0]}B^{[0]}+x(A^{[1]}B^{[0]}+A^{[0]}B^{[1]})+x^2A^{[1]}B^{[1]}\),

we can use the similar method to reduce one multiplication step: \(A^{[1]}B^{[0]}+A^{[0]}B^{[1]}=(A^{[0]}+A^{[1]})(B^{[0]}+B^{[1]})-A^{[0]}B^{[0]}-A^{[1]}B^{[1]}\)

Here is the code:

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Characters that are neither digits nor '-' are skipped; if the input ends
// before a number starts, the program terminates with exit(0).
int read(){
	int c(getchar()); // int, not char: EOF must stay distinct from byte 0xFF
	while(!isdigit(c)&&c!='-'){
		if(c==EOF)exit(0);
		c=getchar();
	}
	bool const neg(c=='-');
	int u(neg?0:c&15);
	while(isdigit(c=getchar()))u=u*10+(c&15);
	return neg?-u:u;
}
// Writes a to stdout in decimal (with a leading '-' if negative), no separator.
void wr(int a){
	// Take the magnitude in unsigned arithmetic: -a overflows for the minimum
	// value of the type, while 0ULL-u is well defined.
	unsigned long long u(a);
	if(a<0)putchar('-'),u=0ULL-u;
	char buf[24];
	int len(0);
	do buf[len++]=char('0'+u%10);while(u/=10);
	while(len)putchar(buf[--len]);
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
// Loop helpers. No `register`: that storage class was removed in C++17, so
// loops declared with it are ill-formed under -std=c++17.
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
typedef vector<int> Vct;
int n,m,s;
Vct a,b,c;
// c[i]=a[i]+b[i] for every index of c (a and b must be at least as long as c).
void add(Vct&a,Vct&b,Vct&c){Frn0(i,0,c.size())c[i]=a[i]+b[i];}
// c[i]=a[i]-b[i] for every index of c (a and b must be at least as long as c).
void mns(Vct&a,Vct&b,Vct&c){Frn0(i,0,c.size())c[i]=a[i]-b[i];}
void mlt(Vct&a,Vct&b,Vct&c);
// Reads A (degree n) and B (degree m), multiplies them with the Karatsuba-style
// mlt, and prints the n+m+1 product coefficients.
signed main(){
	Rd(n),Rd(m);
	// s = smallest power of two strictly greater than max(n,m). Computed with
	// integers: log2(0) is -inf, and converting -inf to int is undefined.
	s=1;
	while(s<=max(n,m))s<<=1;
	a.resize(s),b.resize(s),c.resize(s<<1);
	Frn1(i,0,n)Rd(a[i]);
	Frn1(i,0,m)Rd(b[i]);
	mlt(a,b,c);
	Frn1(i,0,n+m)wr(c[i]),Ps;
	exit(0);
}
// Karatsuba-style divide-and-conquer product of a and b. n is taken from
// a.size() and must be a power of two, with b of the same size. Splitting by
// parity (halves evaluated at x^2):
//   AB = A0B0 + x((A0+A1)(B0+B1) - A0B0 - A1B1) + x^2 A1B1,
// so only 3 recursive products are needed: T(n)=3T(n/2)+O(n)=O(n^log2(3)).
// Writes c[0..2n-2], so c must have at least 2n-1 elements.
void mlt(std::vector<int>&a,std::vector<int>&b,std::vector<int>&c){
	int const n(int(a.size()));
	// Base case before any allocation: leaves are by far the most numerous calls.
	if(n<=1){
		if(n==1)c[0]=a[0]*b[0];
		return;
	}
	int const n2(n>>1);
	std::vector<int> a0(n2),a1(n2),b0(n2),b1(n2),ab0(n),ab1(n),abm(n);
	for(int i=0;i<n2;++i)a0[i]=a[i<<1],a1[i]=a[i<<1|1],b0[i]=b[i<<1],b1[i]=b[i<<1|1];
	// Even coefficients: A0B0 plus x^2*A1B1 (shifted by one slot).
	mlt(a0,b0,ab0),mlt(a1,b1,ab1);
	for(int i=0;i<n;++i)c[i<<1]=ab0[i]+(i?ab1[i-1]:0);
	// Odd coefficients: (A0+A1)(B0+B1) - A0B0 - A1B1. a0/b0 are no longer
	// needed, so they hold the sums.
	for(int i=0;i<n2;++i)a0[i]+=a1[i],b0[i]+=b1[i];
	mlt(a0,b0,abm);
	for(int i=0;i+1<n;++i)c[i<<1|1]=abm[i]-ab0[i]-ab1[i];
}

Results

Better than fake DC multiplication, but even worse than plain multiplication...

Let's calculate the time complexity of this algorithm:

\(T(n)=3T(n/2)+f(n)\), in which \(f(n)=O(n)\).

Using Master Theorem with \(a=3,b=2,\log_ba=\log_2 3\approx1.58>1\), so \(T(n)=O(n^{\log_ba})=O(n^{\log_2 3})\).

Hmm...so why is it even worse than plain multiplication?

Reason 1. The constant factor of DC multiplication is too high.

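Reason 2. In test case \(\#5\), we have \(n=1,m=3\cdot 10^6\). DC multiplication pads both polynomials to the larger degree, so it costs \(O(\max(n,m)^{\log_2 3})\). That is far worse than the \(O(nm)=O(m)\) of plain multiplication here...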

So, our FFT is eventually coming!


Fast Fourier Transform

Fairly Frightening Transform

Let \(n\) be the smallest positive integer power of \(2\) greater than \(\deg A(x)+\deg B(x)\) and we write \(A(x)=\sum_{i=0}^{n-1}a_ix^i,B(x)=\sum_{i=0}^{n-1}b_ix^i\).

Prerequisite knowledge:

  The idea of Divide-and-conquer

  Complex number basics

Linear algebra basics (not strictly required)

Part 1: Two representations of a polynomial

1. Coefficient expressions

For a polynomial \(A(x)=\sum_{i=0}^{n-1}a_ix^i\), its coefficient expression is a vector \(\pmb{a}=\left[\begin{matrix}a_0\\a_1\\\vdots\\a_{n-1}\end{matrix} \right]\)

In coefficient expressions, the time complexities of the following methods are:

  1. Evaluation at a point: \(O(n)\)

  2. Addition: \(O(n)\)

  3. Multiplication: plain \(O(n^2)\), DC \(O(n^{\log_2 3})\)

P.s When calculating polynomial multiplication \(C(x)=A(x)B(x)\), the corresponding coefficient expression \(\pmb{c}\) is defined as the convolution of \(\pmb{a}\) and \(\pmb{b}\), written as \(\pmb{c}=\pmb{a}\bigotimes\pmb{b}\).

2. Point-valued expressions

The point-valued expression of a polynomial \(A(x)\) with \(\deg A<n\) is a set of \(n\) points: \(\{(x_0,y_0),(x_1,y_1),\cdots,(x_{n-1},y_{n-1})\}\)

We can use \(n\) evaluations to convert a coefficient expression to a point-valued expression with a list of \((x_0,x_1,\cdots,x_{n-1})\) in time complexity of \(O(n^2)\) as shown:

\(\left[\begin{matrix}1&x_0&x_0^2&\cdots&x_0^{n-1}\\1&x_1&x_1^2&\cdots&x_1^{n-1}\\\vdots&\vdots&\vdots&\ddots&\vdots\\1&x_{n-1}&x_{n-1}^2&\cdots&x_{n-1}^{n-1}\end{matrix} \right]\left[\begin{matrix}a_0\\a_1\\\vdots\\a_{n-1}\end{matrix} \right]=\left[\begin{matrix}y_0\\y_1\\\vdots\\y_{n-1}\end{matrix} \right]\)

The matrix is written as \(V(x_0,x_1,\cdots,x_{n-1})\), named Vandermonde matrix, so the formula is simplified to \(V(x_0,x_1,\cdots,x_{n-1})\pmb{a}=\pmb{y}\).

Using Lagrange's interpolation formula, a point-valued expression can be converted back into a coefficient expression in \(O(n^2)\) time; this process is called interpolation.

Given two polynomials in point-valued expressions over the same list \((x_0,\cdots,x_{n-1})\), the time complexities of the following operations are:

  1. Addition: \(O(n)\) (Adding the \(y_i\) value respectively)

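  2. Multiplication: \(O(n)\) (multiply the \(y_i\) values pointwise). Since \(\deg(AB)=\deg A+\deg B\), the product is determined uniquely only if there are more than \(\deg A+\deg B\) points. This is why \(n\) was chosen greater than \(\deg A(x)+\deg B(x)\) above.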

This is one central idea of FFT powered polynomial multiplication: with carefully chosen \(x_i\) values, we can achieve evaluation in \(O(n\log n)\), multiplication in \(O(n)\), and finally interpolation in \(O(n\log n)\).

So what are those \(x_i\) values?

Part 2: Complex roots of unity

The \(n\)-th roots of unity are exactly \(n\) complex numbers \(\omega\) that satisfy \(\omega^n=1\), written as:

\(\omega_n^k=e^{2\pi ik/n}=\cos(2\pi k/n)+i\sin(2\pi k/n)\).

We can plot \(n\)-th roots of unity as \(n\) vertices of a regular \(n\)-gon inscribed in the unit circle on the complex plane. For example, the following graph shows the \(8\)-th roots of unity.

There is a pattern: \(\omega_n^j\omega_n^k=\omega_n^{j+k}=\omega_n^{(j+k)\mod n}\). Specifically, \(\omega_n^{-1}=\omega_n^{n-1}\).

Three other important lemmas.

\(\text{Lemma 1. }\) For all integers \(n\geqslant 0,k\geqslant 0,d>0\), we have \(\omega_{dn}^{dk}=\omega_n^k\).

\(\textit{Proof. }\omega_{dn}^{dk}=(e^{2\pi i/(dn)})^{dk}=(e^{2\pi i/n})^k=\omega_n^k.\square\)

\(\text{Lemma 2. }\) For every even positive integer \(n\) and every integer \(k\), we have \((\omega_n^k)^2=(\omega_n^{k+n/2})^2=\omega_{n/2}^k\).

\(\textit{Proof. }(\omega_n^k)^2=\omega_n^{2k},(\omega_n^{k+n/2})^2=\omega_n^{2k+n}=\omega_n^{2k}\). Lastly, \(\omega_n^{2k}=\omega_{n/2}^k\) by \(\text{Lemma 1}.\square\)

\(\text{Lemma 3. }\) For every integer \(n\geqslant 1\) and every integer \(k\) (possibly negative) such that \(n\nmid k\), we have \(\sum_{j=0}^{n-1}(\omega_n^k)^j=0\).

\(\textit{Proof. }\) When \(n\nmid k\), we have \(\omega_n^k\neq 1\), so \(\sum_{j=0}^{n-1}(\omega_n^k)^j=\frac{1-(\omega_n^k)^n}{1-\omega_n^k}=\frac{1-\omega_n^{nk}}{1-\omega_n^k}=\frac{1-1}{1-\omega_n^k}=0.\square\) (Question: why is \(n\nmid k\) necessary?)

The above properties of roots of unity are the essence of FFT optimization.

Part 3: Discrete Fourier Transform

Recall the definition of \(n\), which is a power of \(2\). DFT is just the evaluation of coefficient expressed \(A(x)\) on \(n\)-th roots of unity. We write the Vandermonde matrix as

\(V_n=V(\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1})=\left[\begin{matrix}1&1&1&1&\cdots&1\\1&\omega_n&\omega_n^2&\omega_n^3&\cdots&\omega_n^{n-1}\\1&\omega_n^2&\omega_n^4&\omega_n^6&\cdots&\omega_n^{2(n-1)}\\1&\omega_n^3&\omega_n^6&\omega_n^9&\cdots&\omega_n^{3(n-1)}\\\vdots&\vdots&\vdots&\vdots&\ddots&\vdots\\1&\omega_n^{n-1}&\omega_n^{2(n-1)}&\omega_n^{3(n-1)}&\cdots&\omega_n^{(n-1)(n-1)}\end{matrix} \right]\),

then the formula of DFT is \(\pmb{y}=\text{DFT}_n(\pmb a)\): \(V_n\pmb{a}=\pmb{y}\). Specifically, \(y_i=\sum_{j=0}^{n-1}[V_n]_{ij}a_j=\sum_{j=0}^{n-1}\omega_n^{ij}a_j\).

So, how can we achieve it in \(O(n\log n)\)?

Part 4: FFT

Like DC multiplication, we split the polynomial by parity: \(A(x)=A^{[0]}(x^2)+xA^{[1]}(x^2)\), where \(A^{[0]}(x)=\sum_{i=0}^{n/2-1}a_{2i}x^i,A^{[1]}(x)=\sum_{i=0}^{n/2-1}a_{2i+1}x^i\).

Then, our evaluation of \(A(x)\) on \(\omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1}\) becomes

1. Divide-and-conquer: evaluating \(A^{[0]}(x)\) and \(A^{[1]}(x)\) on \((\omega_n^0)^2,(\omega_n^1)^2,\cdots,(\omega_n^{n-1})^2\).

By \(\text{Lemma 2}\), the list \((\omega_n^0)^2,(\omega_n^1)^2,\cdots,(\omega_n^{n-1})^2\) is exactly the list of \(n/2\)-th roots of unity \(\omega_{n/2}^0,\cdots,\omega_{n/2}^{n/2-1}\), repeated twice (Why?)

So we can apply \(\text{DFT}_{n/2}(\pmb a^{[0]})=\pmb y^{[0]},\text{DFT}_{n/2}(\pmb a^{[1]})=\pmb y^{[1]}\). And the second step is

2. Combining the answers.

As \(\omega_n^{n/2}=e^{2\pi i (n/2)/n}=e^{\pi i}=-1\) (The beautiful Euler's formula!),

we have \(\omega_n^{k+n/2}=\omega_n^k\omega_n^{n/2}=-\omega_n^k\),

so \(y_i=y^{[0]}_i+\omega_n^i y^{[1]}_i,y_{i+n/2}=y^{[0]}_i-\omega_n^i y^{[1]}_i,\) for all \(i=0,1,\cdots,n/2-1\).

Specifically, when \(n=1\), \(\omega_1^0 a_0=a_0\) in the trivial case.

Let's calculate the time complexity

\(T(n)=2T(n/2)+f(n)\), in which \(f(n)=O(n)\) is the time used for combination.

Using Master Theorem with \(a=2,b=2,\log_ba=\log_2 2=1\), we have \(T(n)=O(n^{\log_ba}\log n)=O(n\log n)\). Whooo!

Part 5: Inverse DFT

Don't celebrate too soon, there is still interpolation. Awww

Since \(\pmb{y}=\text{DFT}_n(\pmb{a})=V_n\pmb{a}\), we have \(\pmb{a}=V_n^{-1}\pmb{y}\), written as \(\pmb{a}=\text{DFT}_n^{-1}(\pmb{y})\).

\(\text{Theorem. }\) For all \(i,j=0,1,\cdots,n-1\), we have \([V_n^{-1}]_{ij}=\omega_n^{-ij}/n\).

\(\textit{Proof. }\) We show that \(V_n^{-1}V_n=I_n\) the identity matrix:

\([V_n^{-1}V_n]_{ij}=\sum_{k=0}^{n-1}(\omega_n^{-ik}/n)\omega_n^{kj}=\frac{\sum_{k=0}^{n-1}\omega_n^{-ik}\omega_n^{kj}}{n}=\frac{\sum_{k=0}^{n-1}\omega_n^{(j-i)k}}{n}\)

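If \(i=j\), the numerator is \(\sum_{k=0}^{n-1}\omega_n^0=n\), so the entry is \(n/n=1\). Otherwise \(0<|j-i|<n\). Let \(t=(j-i)\bmod n\); then \(0<t<n\), so \(n\nmid t\) and \(\omega_n^{(j-i)k}=(\omega_n^t)^k\). By \(\text{Lemma 3}\) the numerator is \(0\), so the entry is \(0/n=0\). Therefore, \(V_n^{-1}V_n=I_n\). \(\square\)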

Next, \(\pmb{a}=\text{DFT}_n^{-1}(\pmb{y})=V_n^{-1}\pmb{y}\), in which \(a_i=\sum_{j=0}^{n-1}[V_n^{-1}]_{ij}y_j=\sum_{j=0}^{n-1}(\omega_n^{-ij}/n)y_j=\frac{\sum_{j=0}^{n-1}\omega_n^{-ij}y_j}{n}\).

Let's compare: in DFT, \(y_i=\sum_{j=0}^{n-1}\omega_n^{ij}a_j\).

Therefore, we can convert DFT to IDFT by simply replacing \(\omega_n^k\) with \(\omega_n^{-k}\) and dividing the final answers by \(n\).

Part 6: Recursive Implementation

According to the previous text, we just need to modify the code of DC multiplication.

To save memory, we redistribute the coefficients of \(A^{[0]}\) to the left and \(A^{[1]}\) to the right.

In the code, o\(=\omega_n\), w\(=\omega_n^i\).

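P.S. Don't forget to divide by \(n\) for IDFT. Because of floating-point error, the results are only approximately integers. The code therefore rounds them to the nearest integer before printing, instead of truncating.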

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Characters that are neither digits nor '-' are skipped; if the input ends
// before a number starts, the program terminates with exit(0).
int read(){
	int c(getchar()); // int, not char: EOF must stay distinct from byte 0xFF
	while(!isdigit(c)&&c!='-'){
		if(c==EOF)exit(0);
		c=getchar();
	}
	bool const neg(c=='-');
	int u(neg?0:c&15);
	while(isdigit(c=getchar()))u=u*10+(c&15);
	return neg?-u:u;
}
// Writes a to stdout in decimal (with a leading '-' if negative), no separator.
void wr(int a){
	// Take the magnitude in unsigned arithmetic: -a overflows for the minimum
	// value of the type, while 0ULL-u is well defined.
	unsigned long long u(a);
	if(a<0)putchar('-'),u=0ULL-u;
	char buf[24];
	int len(0);
	do buf[len++]=char('0'+u%10);while(u/=10);
	while(len)putchar(buf[--len]);
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
// Loop helpers. No `register`: that storage class was removed in C++17, so
// loops declared with it are ill-formed under -std=c++17.
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
double const Pi(acos(-1));
typedef complex<double> Cpx;
#define N (2100000)
Cpx o,w,a[N],b[N],tmp[N],x,y;
int n,m,s;
bool iv;
void fft(Cpx*a,int n);
// Reads A (degree n) and B (degree m), multiplies them with the recursive FFT,
// and prints the n+m+1 product coefficients.
signed main(){
	Rd(n),Rd(m);
	// s = smallest power of two strictly greater than n+m. Computed with
	// integers: log2(0) is -inf, and converting -inf to int is undefined.
	s=1;
	while(s<=n+m)s<<=1;
	Frn1(i,0,n)Rd(a[i]);
	Frn1(i,0,m)Rd(b[i]);
	fft(a,s),fft(b,s);
	Frn0(i,0,s)a[i]*=b[i];
	iv=1,fft(a,s);
	// Round to nearest. int(v+0.5) truncates toward zero, which is wrong
	// for negative results.
	Frn1(i,0,n+m)wr(int(lround(a[i].real()/s))),Ps;
	exit(0);
}
// Recursive radix-2 FFT, in place on a[0..n-1] (n a power of two).
// Global flag iv selects the direction: false = forward DFT, true = the
// conjugated twiddles of the inverse DFT, WITHOUT the final division by n
// (main divides by s when printing). Scratch state lives in the globals
// tmp, o, w, x, y.
void fft(Cpx*a,int n){
	if(n==1)return;
	int const half(n>>1);
	// Even-indexed coefficients go to the left half, odd-indexed to the right.
	for(int k=0;k<half;++k){
		tmp[k]=a[2*k];
		tmp[half+k]=a[2*k+1];
	}
	copy(tmp,tmp+n,a);
	fft(a,half);
	fft(a+half,half);
	// o = e^{+-2*pi*i/n}; note that Pi/half equals 2*Pi/n.
	double const angle(Pi/half);
	o=Cpx(cos(angle),(iv?-1:1)*sin(angle));
	w=1;
	for(int k=0;k<half;++k){
		x=a[k];
		y=w*a[k+half];
		a[k]=x+y;
		a[k+half]=x-y;
		w*=o;
	}
}

Time complexity: \(O(n\log n)\)

Memory complexity: \(O(n)\)

Results:

Not fully AC, as recursive implementation is not fast enough.

Part 7: Iterative Implementation

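Let \(n=\deg A\) and \(m=\deg B\) (as read by the code), and for \(n+m\geqslant 1\) let \(l=\lfloor\log_2(n+m)\rfloor+1\) and \(s=2^l\). Then \(s\) is the smallest power of \(2\) strictly greater than \(n+m\), i.e. the "\(n\)" in previous parts, and \(l\) is the number of bits of an index in \([0,s)\).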

Similarly, we redistribute the coefficients of \(A^{[0]}\) to the left and \(A^{[1]}\) to the right.

Observe the pattern of redistribution in each layer of recursion. Take \(s=8\) as an example:

0-> 0 1 2 3 4 5 6 7
1-> 0 2 4 6|1 3 5 7
2-> 0 4|2 6|1 5|3 7
end 0|4|2|6|1|5|3|7

Still confused? Write them in base-2:

0-> 000 001 010 011 100 101 110 111
1-> 000 010 100 110|001 011 101 111
2-> 000 100|010 110|001 101|011 111
end 000|100|010|110|001|101|011|111

The base-2 expressions are reversed in the last layer!

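A hint of the proof: each layer sends an element left or right according to the lowest bit of its remaining index (even to the left, odd to the right), then drops that bit. So the lowest bit of the original index decides the highest bit of the final position, the next bit decides the next-highest bit, and so on. The final position is the bit-reversal of the index.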

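In the code, we use array \(r_{0..s-1}\) to store the reversed numbers, computed by the recurrence \(r_i=\lfloor r_{\lfloor i/2\rfloor}/2\rfloor+(i\bmod 2)\cdot 2^{l-1}\). Reversing \(i\) is the same as reversing \(\lfloor i/2\rfloor\) (that is, \(i\) without its lowest bit), shifting the result right by one, and putting the lowest bit of \(i\) on top.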

Butterfly Operation

It is already written in the code of recursive implementation, but let's clarify that:

Still remember \(y_i=y^{[0]}_i+\omega_n^i y^{[1]}_i,y_{i+n/2}=y^{[0]}_i-\omega_n^i y^{[1]}_i,i=0,1,\cdots,n/2-1\)?

To save memory, we do not create the array \(\pmb y\), but the combination is done on the original location of the array \(\pmb a\).

After redistribution, we have \(a^{[0]}_i=a_i\) and \(a^{[1]}_i=a_{i+n/2}\).

Let \(x=a^{[0]}_i=a_i,y=\omega_n^i a^{[1]}_i=\omega_n^i a_{i+n/2}\),

then the result of DFT is simply \(a_i=x+y,a_{i+n/2}=x-y\)!

With Butterfly Operation, we just need to redistribute the coefficients according to \(r\), and then combine iteratively to implement FFT.

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Characters that are neither digits nor '-' are skipped; if the input ends
// before a number starts, the program terminates with exit(0).
int read(){
	int c(getchar()); // int, not char: EOF must stay distinct from byte 0xFF
	while(!isdigit(c)&&c!='-'){
		if(c==EOF)exit(0);
		c=getchar();
	}
	bool const neg(c=='-');
	int u(neg?0:c&15);
	while(isdigit(c=getchar()))u=u*10+(c&15);
	return neg?-u:u;
}
// Writes a to stdout in decimal (with a leading '-' if negative), no separator.
void wr(int a){
	// Take the magnitude in unsigned arithmetic: -a overflows for the minimum
	// value of the type, while 0ULL-u is well defined.
	unsigned long long u(a);
	if(a<0)putchar('-'),u=0ULL-u;
	char buf[24];
	int len(0);
	do buf[len++]=char('0'+u%10);while(u/=10);
	while(len)putchar(buf[--len]);
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
// Loop helpers. No `register`: that storage class was removed in C++17, so
// loops declared with it are ill-formed under -std=c++17.
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
double const Pi(acos(-1));
typedef complex<double> Cpx;
#define N (2100000)
Cpx a[N],b[N],o,w,x,y;
int n,m,l,s,r[N];
void fft(Cpx*a,bool iv);
// Reads A (degree n) and B (degree m), multiplies them with the iterative FFT,
// and prints the n+m+1 product coefficients.
signed main(){
	Rd(n),Rd(m);
	// l = index bit count, s = 2^l = smallest power of two > n+m. Computed with
	// integers (log2(0) is -inf); l >= 1 keeps (i&1)<<(l-1) a valid shift.
	l=1;
	while((1<<l)<=n+m)++l;
	s=1<<l;
	Frn1(i,0,n)Rd(a[i]);
	Frn1(i,0,m)Rd(b[i]);
	// r[i] = bit-reversal of i over l bits.
	Frn0(i,0,s)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	fft(a,0),fft(b,0);
	Frn0(i,0,s)a[i]*=b[i];
	fft(a,1);
	// Round to nearest. int(v+0.5) truncates toward zero, which is wrong
	// for negative results.
	Frn1(i,0,n+m)wr(int(lround(a[i].real()))),Ps;
	exit(0);
}
// Iterative radix-2 FFT over a[0..s-1]. The size s and the bit-reversal
// table r are globals prepared in main. iv=false: forward DFT; iv=true:
// inverse DFT, including the final division by s. The globals o, w, x, y are
// used as temporaries.
void fft(Cpx*a,bool iv){
	// Move every element to its bit-reversed slot; i<r[i] swaps each pair once.
	for(int i=0;i<s;++i){
		if(i<r[i])swap(a[i],a[r[i]]);
	}
	// Merge adjacent blocks of length half into blocks of length len=2*half.
	for(int half=1;(half<<1)<=s;half<<=1){
		int const len(half<<1);
		// o = e^{+-2*pi*i/len}; Pi/half equals 2*Pi/len.
		o=Cpx(cos(Pi/half),(iv?-1:1)*sin(Pi/half));
		for(int start=0;start<s;start+=len){
			w=1;
			for(int k=0;k<half;++k){
				// Butterfly: (x, w*y) -> (x+w*y, x-w*y).
				x=a[start+k];
				y=w*a[start+k+half];
				a[start+k]=x+y;
				a[start+k+half]=x-y;
				w*=o;
			}
		}
	}
	if(iv){
		for(int i=0;i<s;++i)a[i]/=s;
	}
}

Time complexity: \(O(n\log n)\)

Memory complexity: \(O(n)\)

Results:

Celebrate


Extension: Number Theoretic Transform

Although FFT has excellent time complexity, inaccuracy will inevitably arise because of the use of complex numbers.

If the polynomial coefficients and the resulting coefficients are non-negative integers smaller than the modulus \(P\) used below, NTT is a better choice: it is exact, with no floating-point error.

Prerequisite knowledge:

  FFT absolutely

  Modular arithmetic basics

Primitive roots

Assume that the following calculations are in the context of \(\bmod P\), where \(P\) is a prime number.

For a positive integer \(g\), if the powers of \(g\) (taken \(\bmod P\)) include every positive integer \(<P\), then we call \(g\) a primitive root \(\bmod P\). (Digression: in group theory, this says that the residue class of \(g\) in \(\mathbb{Z}_P\) generates the multiplicative group \(\mathbb{Z}_P^*\).)

E.g. For \(P=7\), we list, for each positive integer \(<P\), the set of values its powers take \(\bmod 7\).

1-> {1}
2-> {1,2,4}
3-> {1,2,3,4,5,6}
4-> {1,2,4}
5-> {1,2,3,4,5,6}
6-> {1,6}

Therefore, \(3,5\) are the primitive roots \(\bmod 7\).

In the code, we commonly use \(P=998244353,g=3\).

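The special property of a primitive root \(g\) is that its powers repeat with minimal period exactly \(P-1\). By Fermat's little theorem \(g^{P-1}=1\), and no smaller positive exponent gives \(1\).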

E.g Let \(P=7,g=3\), then the powers of \(g\) (beginning with \(g^0\)) are:\(1,3,2,6,4,5,1,3,2,6,4,5,\cdots\).

This property is very similar to the roots of unity. If we take \(n=P-1\) and \(\omega_n=g\), then all three lemmas in the FFT part are satisfied.

However, to complete NTT, there is one last step.

The substitute for roots of unity

In FFT, we use \(n\)-th roots of unity, where \(n\) is a power of \(2\).

However, \(P-1\) is not necessarily \(n\). Hence, we cannot directly replace \(\omega_n\) with \(g\).

Now, as the powers of \(g\) have a period of \(P-1\),

if we take a factor \(k\) of \(P-1\), then the powers of \(g^k\) have a period of \(\frac{P-1}{k}\). (Why?)

This means that if we take \(k=\frac{P-1}{n}\), then the powers of \(g^k\) have a period of exactly \(n\).

But, how can we be sure that \(n\) is always a factor of \(P-1\)?

This is why we choose \(P=998244353\): \(P-1=998244352=2^{23}\cdot 7\cdot 17\) contains a large power of \(2\). Every power of two \(n\leqslant 2^{23}\approx 8.4\cdot 10^6\) therefore divides \(P-1\).

Therefore, \(g^{\frac{P-1}{n}}\) is just our substitute of \(\omega_n\).

In the code, we use \(g^{-1}=332748118\) (indeed \(3\cdot 332748118=998244354\equiv 1\)). When doing IDFT we multiply by \(s^{-1}=s^{P-2}\), by Fermat's little theorem. Make sure that you include \(\bmod P\) in every operation.

//This program is written by Brian Peng.
#pragma GCC optimize("Ofast","inline","no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define Rd(a) (a=read())
#define Gc(a) (a=getchar())
#define Pc(a) putchar(a)
// Reads the next decimal integer (optionally preceded by '-') from stdin.
// Characters that are neither digits nor '-' are skipped; if the input ends
// before a number starts, the program terminates with exit(0).
int read(){
	int c(getchar()); // wider than char: EOF must stay distinct from byte 0xFF
	while(!isdigit(c)&&c!='-'){
		if(c==EOF)exit(0);
		c=getchar();
	}
	bool const neg(c=='-');
	int u(neg?0:c&15);
	while(isdigit(c=getchar()))u=u*10+(c&15);
	return neg?-u:u;
}
// Writes a to stdout in decimal (with a leading '-' if negative), no separator.
void wr(int a){
	// Take the magnitude in unsigned arithmetic: -a overflows for the minimum
	// value of the type, while 0ULL-u is well defined.
	unsigned long long u(a);
	if(a<0)putchar('-'),u=0ULL-u;
	char buf[24];
	int len(0);
	do buf[len++]=char('0'+u%10);while(u/=10);
	while(len)putchar(buf[--len]);
}
signed const INF(0x3f3f3f3f),NINF(0xc3c3c3c3);
long long const LINF(0x3f3f3f3f3f3f3f3fLL),LNINF(0xc3c3c3c3c3c3c3c3LL);
#define Ps Pc(' ')
#define Pe Pc('\n')
// Loop helpers. No `register`: that storage class was removed in C++17, so
// loops declared with it are ill-formed under -std=c++17.
#define Frn0(i,a,b) for(int i(a);i<(b);++i)
#define Frn1(i,a,b) for(int i(a);i<=(b);++i)
#define Frn_(i,a,b) for(int i(a);i>=(b);--i)
#define Mst(a,b) memset(a,b,sizeof(a))
#define File(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout)
#define P (998244353)
#define G (3)
#define Gi (332748118)
#define N (2100000)
int n,m,l,s,r[N],a[N],b[N],o,w,x,y,siv;
// Returns a^p mod P by iterative square-and-multiply (p >= 0).
// Spelled `long long` explicitly: this is what `int` expands to in this
// program, and 64-bit arithmetic is required because a*a can reach ~P^2.
long long fpw(long long a,long long p){
	long long res(1);
	for(;p;p>>=1){
		if(p&1)res=res*a%P;
		a=a*a%P;
	}
	return res;
}
void ntt(int*a,bool iv);
// Reads A (degree n) and B (degree m), multiplies them with NTT modulo P, and
// prints the n+m+1 product coefficients as residues mod P.
signed main(){
	Rd(n),Rd(m);
	// l = index bit count, s = 2^l = smallest power of two > n+m. Computed with
	// integers (log2(0) is -inf); l >= 1 keeps (i&1)<<(l-1) a valid shift.
	l=1;
	while((1<<l)<=n+m)++l;
	s=1<<l;
	siv=fpw(s,P-2); // s^{-1} mod P, by Fermat's little theorem
	// Reduce inputs into [0,P): the butterflies' (x-y+P)%P assumes residues in [0,P).
	Frn1(i,0,n){Rd(a[i]);a[i]=(a[i]%P+P)%P;}
	Frn1(i,0,m){Rd(b[i]);b[i]=(b[i]%P+P)%P;}
	// r[i] = bit-reversal of i over l bits.
	Frn0(i,0,s)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	ntt(a,0),ntt(b,0);
	Frn0(i,0,s)a[i]=a[i]*b[i]%P;
	ntt(a,1);
	Frn1(i,0,n+m)wr(a[i]),Ps;
	exit(0);
}
void ntt(int*a,bool iv){
	Frn0(i,0,s)if(i<r[i])swap(a[i],a[r[i]]);
	for(int i(2),i2(1);i<=s;i2=i,i<<=1){
		o=fpw(iv?Gi:G,(P-1)/i);
		for(int j(0);j<s;j+=i){
			w=1;
			Frn0(k,0,i2){
				x=a[j+k],y=w*a[j+k+i2]%P;
				a[j+k]=(x+y)%P,a[j+k+i2]=(x-y+P)%P,w=w*o%P;
			}
		}
	}
	if(iv)Frn0(i,0,s)a[i]=a[i]*siv%P;
}

Time complexity: \(O(n\log n)\)

Memory complexity: \(O(n)\)

Results

No significant improvement in time, but the memory cost is reduced. Each element is an 8-byte integer (note `#define int long long`) instead of a 16-byte complex<double>.


The End:

Translating is sooooo time-consuming...

Another year with Cnblogs! Happy new year!

Thanks for your support! ありがとう!


Reference:

Introduction to Algorithms

自为风月马前卒:快速傅里叶变换(FFT)详解

自为风月马前卒:快速数论变换(NTT)小结

posted @ 2022-01-04 19:06  BrianPeng  阅读(213)  评论(0编辑  收藏  举报