[笔记] [题解]多项式学习
[笔记] [题解]多项式学习
\(\mathcal{FFT}\)
参考博客:\(\textit{litble}\)学长的校内博客
前置知识
在接下来的讲解中可能会用到一些高中数学知识,现在先稍微讲解一下(主要是我不会啊)
虚数&复数
定义
虚数的基本单位:\(i=\sqrt{-1}\)
复数:一个复数\(x\)可以表示为\(x = a+bi\)
运算
加减:\(\large (a + bi)+(c + di)=(a+c)+(b+d)i\)
乘除:\(\large (a+bi)(c+di)=ac+bci+adi-bd=(ac-bd)+(bc+ad)i\)
其他形式
复数的三角形式如图
偷来的图
其中\(\theta\)是复数的辐角,可以表示成\(\theta+2k\pi\)的形式。
那么我们可以得出复数乘除运算的三角式:
\(\large r_1(\cos\theta_1+i\ \sin\theta_1)r_2(\cos\theta_2+i\ \sin\theta_2)=r_1r_2(\cos(\theta_1+\theta_2)+\sin(\theta_1+\theta_2))\)
也就是说复数积的模等于各个复数的模的积,乘积的辐角等于各个复数辐角的和。
单位根
\(n\)次单位根就是满足\(\omega^n=1\)的复数
由复数的三角式乘除运算法则可以知道有\(n\)个这样的复数,它们分布于平面的单位圆上,并且平分这个单位圆。
\(n\)次单位根是:\(\large e^{\frac{2\pi ki}{n}},k=0,1,2,\dots,n-1\)
还有一个接下来的变形中会用到的公式:欧拉公式:\(e^{\theta i}=\text{cos}\theta+i\ \text{sin}\theta\)
因此就可以得到\(n\)次单位根的算术表示法:记\(\omega_n=e^{\frac{2\pi i}{n}}\)
总结一下单位根的性质:
多项式乘法
给定多项式\(\large A(x)=\sum^n_{i=0}a_ix^i\)和\(\large B(x)=\sum^n_{i=0}b_ix^i\),则它们的积是\(\large C(x)=\sum_{j+k=i,0\leq j,k\leq n}\ a_jb_kx^i\)
\(\mathcal{FFT}\)具体知识
折半引理
对于\(n>0\)且\(n\)为奇数有:
证明:\(\large (\omega_n^{k+\frac{n}{2}})^2=\omega_n^{2k+n}=\omega^{2k}_n\omega^n_n=(\omega^k_n)^2=\omega^k_{n/2}\)。可以参照上面的欧拉公式。
快速傅里叶变换
就是多项式快速转化为点值表示法。
首先进行奇偶性划分:
所以就把\(A(x)\)变成了\(A_0(x^2)+xA_1(x^2)\)
同时用复数\(\omega^k_n\)来加速。(在下文蝶形变换有差不多的解释,看不懂的没关系)
可以得到\(\large A(\omega^k_n)=A_0((\omega^k_n)^2)+\omega^k_n A_1((\omega^k_n)^2)------------(1)\)
由于\(\large (e^{\frac{2\pi i}{n}})^{2k}=(e^{\frac{4\pi i}{n}})^k\)
所以\(\large A(\omega^k_n)=A_0(\omega^k_{n/2})+\omega^k_{n/2}\)
根据折半引理,\(\large \omega_n^{k+\frac{n}{2}}=(e^{\frac{2\pi i}{n}})^{k+\frac{n}{2}}=(e^{\frac{2\pi i}{n}})^ke^{\pi i}=\omega^k_n\omega_n^{\frac{n}{2}}=-\omega^k_n\)
可以得到\(\large A(\omega_n^{k+n/2})=A_0(\omega^k_{n/2})-\omega^k_nA_1(\omega^k_{n/2})-----------(2)\)
当\(\large k\in[0,\frac{n}{2}-1]\)时,\(\large k+\frac{n}{2}\in[\frac{n}{2},n-1]\)
这样利用分支来实现的复杂度是\(O(nlog_2n)\)
蝶形变换
这种算法的英文名称是\(Cooley-Tukey\)算法。
假设现在有一个\(n-1\)次多项式\(\large A(x)=\sum^{n-1}_{i=0}a_ix^i\)(方便起见,设\(n=2^m,m\in\Z\))
将\(n\)个\(n\)次单位根\(\omega^0_n,\omega^1_n,\dots,\omega^{n-1}_n\)带入多项式\(A(x)\)将其转换成点值表达
接下来把每一项进行奇偶分类
前面有提到\(\large \omega^2_n=(e^{\frac{2\pi i}{n}})=e^{\frac{2\pi i}{n/2}}=\omega_{\frac{n}{2}}\),也就是说要带入的值经过平方之后变少了一半,原因是单位根把单位元平分,那么肯定具有对称性,所以说肯定有一正一负两个,平方之后自然就相等了。
也就是说当\(k<\frac{n}{2}\)时
这样我们带入的值也就变成了\(\large 1,\omega_{\frac{n}{2}}^1,\omega_{\frac{n}{2}}^2,\dots,\omega_{\frac{n}{2}}^{\frac{n}{2}-1},\)也就是把单位圆上的单位根一次代入,这样的复杂度就是\(\large O(nlog_2n)\)
举一个具体一点的例子来描述一下奇偶分类的具体过程:
初始的系数:\(\large \omega_n^0\omega_n^1\omega_n^2\omega_n^3\omega_n^4\omega_n^5\omega_n^6\omega_n^7\)
一次变换后:\(\large \omega^0_n\omega^2_n\omega^4_n\omega^6_n\omega^1_n\omega^3_n\omega^5_n\omega^7_n\)
两次变换后:\(\large \omega^0_n\omega^4_n\omega^2_n\omega^6_n\omega^1_n\omega^5_n\omega^3_n\omega^7_n\)
傅里叶逆变换
我目前不是很懂,不过过程是:把原来傅里叶变换中\(\omega_n^i\)换成\(\omega_n^{-i}\),然后做一次傅里叶变换,之后把得到的结果除以\(n\)即可。
代码实现
这个是这个题
#include <bits/stdc++.h>
using namespace std;
const int N = 3000010;
const double pi = 3.1415926535897384626;
struct complex_num{
double r,i;
}a[N],b[N];
int n,m,len,rev[N];
complex_num operator + (complex_num a,complex_num b){
return (complex_num){a.r + b.r,a.i + b.i};
}
complex_num operator - (complex_num a,complex_num b){
return (complex_num){a.r - b.r,a.i - b.i};
}
complex_num operator * (complex_num a,complex_num b){
return (complex_num){a.r * b.r - a.i * b.i,a.i * b.r + a.r * b.i};
}
complex_num operator / (complex_num a,double c){
return (complex_num){a.r / c,a.i / c};
}
void FFT(complex_num *a,int x){
for(int i = 0;i < n;i++)
if(i < rev[i])
swap(a[i],a[rev[i]]);//防止一个元素交换两次回到它原来的位置
for(int i = 1;i < n;i <<= 1){
complex_num wn = (complex_num){cos(pi / i),x * sin(pi / i)};
for(int j = 0;j < n;j += (i << 1)){
complex_num w = (complex_num){1,0},tmp1,tmp2;
for(int k = 0;k < i;k++,w = w * wn){
tmp1 = a[j + k],tmp2 = w * a[j + k + i];
a[j + k] = tmp1 + tmp2;a[j + k + i] = tmp1 - tmp2;
}
}
}
if(x == -1)for(int i = 0;i < n;i++)a[i] = a[i] / n;
}
int main(){
scanf("%d%d",&n,&m);
for(int i = 0;i <= n;i++)scanf("%lf",&a[i].r);
for(int i = 0;i <= m;i++)scanf("%lf",&b[i].r);
m = n + m;
for(n = 1;n <= m;n <<= 1)len++;
for(int i = 0;i < n;i++)rev[i] = (rev[i >> 1] >> 1) | (i & 1) << (len - 1);
FFT(a,1);FFT(b,1);
for(int i = 0;i <= n;i++)a[i] = a[i] * b[i];
FFT(a,-1);
for(int i = 0;i <= m;i++)printf("%d ",(int)(a[i].r + 0.5));
return 0;
}
还有这个题
#include <bits/stdc++.h>
using namespace std;
const long long N = 3000010;
const double pi = 3.1415926535897384626;
struct com{
double r,i;
}a[N],b[N];
long long ans[N];
long long n,m,len,rev[N];
char s[N];
com operator + (com a,com b){
return (com){a.r + b.r,a.i + b.i};
}
com operator - (com a,com b){
return (com){a.r - b.r,a.i - b.i};
}
com operator * (com a,com b){
return (com){a.r * b.r - a.i * b.i,a.r * b.i + b.r * a.i};
}
com operator / (com a,double c){
return (com){a.r / c,a.i / c};
}
void FFT(com *a,long long x){
for(long long i = 0;i < n;i++)if(i < rev[i])swap(a[i],a[rev[i]]);//防止交换两次,等同于没有交换
for(long long i = 1;i < n;i <<= 1){//i是准备合并的序列的长度的一半
com wn = (com){cos(pi / i),x * sin(pi / i)};//单位根
for(long long j = 0;j < n;j += (i << 1)){//j是合并到了哪一位
com w = (com){1,0},tmp1,tmp2;
for(long long k = 0;k < i;k++,w = w * wn){//只扫左半部分,同时得到右半部分的答案(蝴蝶变换)
tmp1 = a[j + k],tmp2 = w * a[j + k + i];
a[j + k] = tmp1 + tmp2;//对应上面快速傅里叶变换的(1)
a[j + k + i] = tmp1 - tmp2;//对应上面快速傅里叶变换的(2)
}
}
}
if(x == -1)for(long long i = 0;i < n;i++)a[i] = a[i] / n;
}
signed main(){
scanf("%s",&s);
n = strlen(s);
for(long long i = n - 1;i >= 0;i--)a[n - i - 1].r = s[i] - '0';//下标从0开始
scanf("%s",&s);
m = strlen(s);
for(long long i = m - 1;i >= 0;i--)b[m - i - 1].r = s[i] - '0';
m = n + m;
for(n = 1;n <= m;n <<= 1)len++;
for(long long i = 0;i < n;i++)rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (len - 1)));
FFT(a,1);FFT(b,1);
for(long long i = 0;i <= n;i++)a[i] = a[i] * b[i];
FFT(a,-1);
for(long long i = 0;i <= m;i++)ans[i] = (long long)(a[i].r + 0.5);
len = 0;
for(long long i = 0;i <= m;i++){//算出来的是系数,可能大于10,要进位
ans[i + 1] += ans[i] / 10;
ans[i] %= 10;
len++;
}
while(ans[len] >= 10){
ans[++len] += ans[len - 1] / 10;
ans[len - 1] %= 10;
}
while(ans[len] == 0)len--;
for(long long i = len;i >= 0;i--)printf("%d",ans[i]);
printf("\n");
return 0;
}
一些对于程序的解释
在第二个程序中,经过第\(46\)行的操作,可以得到\(2^k\ge2n\)且\(2^{k-1}<2n\),\(n=2^k\)。这样是为了使无论将单位圆分成几份都会是整数份数。
第\(47\)行的程序的作用是:因为我们需要进行奇偶分类,这里就有一个性质,比如说现在要算出下标为\(4\)的元素在奇偶分类之后排在哪一位,那么我们先表示出\(4\)的二进制数\(0100\),再把这个二进制数颠倒得到\(0010\)对应的十进制数就是下标为\(4\)的元素的位置,位于第\(2\)个。可以参见上面奇偶分类系数的变换。
在\(\mathcal{FFT}\)函数中,三重循环的三个循环变量\(i,j,k\)分别代表:把单位圆分成几份,从第几个单位根开始在单位圆上转,当前计算到了哪一个单位根。
\(\mathcal{NTT}\)
在实现\(\mathcal{FFT}\)的时候我们会发现其实在计算过程中是有精度损失的,因为我们利用\(\omega_n\)实现了折半引理。有没有什么整数可以用来代替\(\omega_n\)呢?
原根可以取而代之。定义\(P\)的原根为满足\(\large g^{\phi(P)}\equiv1(\mod P)\)的整数\(g\)。
我们用\(\large g^{\frac{\phi(P)}{n}}\)代替\(n\)次单位根进行计算,因为\(P\)是质数,所以\(\phi(P)=P-1\),有要求\(\large \frac{\phi(P)}{n}\)为整数,\(n\)还是\(2\)的整数次幂,所以要求\(\large P=k*2^q+1\),其中\(2^q\ge n\)。
怎么求原根呢?如果题目没有给出模数,就要用\(\mathcal{BSGS}\),如果\(P\)不是质数就要用中国剩余定理合并。
另一种定义是:若有\(g\)使得\(g^i\mod P\)的结果两两不同,\(P\)为质数,且\(g\in[2,p-1]i,i\in[1,p-1]\),那么称\(g\)是\(P\)的原根。比如说\(998244353\)的原根就是\(3\)。
代码的坑还没填。。。
拉格朗日插值
先放一道例题
题目大意
给出\(n\)个点\(P_i(x_i,y_i)\),将过这\(n\)个点的最多\(n-1\)次的多项式记为\(f(x)\),求\(f(k)\)的值。
拉格朗日插值
设我们现在有给定的\(n+ 1\)个点,分别是\((x_0,y_0),(x_1,y_1),\dots,(x_n,y_n)\)
则拉格朗日基本多项式为
我们可以发现\(\large \ell_j(x_j)=1\),并且\(\large \ell_j(x_i)=0,\forall i\ne j\),也就是说\(\large \ell_j(x_i)\)函数的作用就是让函数的返回值只有\(0\)或\(1\),而且在传入\(x_j\)的时候返回\(1\),其余时候返回\(0\)。
接着就是\(n\)次多项式
观察上式,我们可以发现\(\large P(x_i)=y_i\),也就是经过了给定的\(n+1\)个点。
整合上面的两个公式得到最终的拉格朗日插值法的公式:
对于例题而言只要求出\(f(k)\)的值即可。
代码
#include <bits/stdc++.h>
using namespace std;
const long long mod = 998244353;
long long x[20010],y[20010],ans,tmp1,tmp2;
inline long long qpow(long long x,long long y){
long long res = 1ll;
while(y){
if(y & 1){
res = res * x;
res %= mod;
}
x = x * x;x %= mod;
y >>= 1;
}
return res;
}
long long n,k;
int main(){
scanf("%lld%lld",&n,&k);
for(int i = 1;i <= n;i++)scanf("%lld%lld",&x[i],&y[i]);
for(int i = 1;i <= n;i++){
tmp1 = y[i] % mod;
tmp2 = 1ll;
for(int j = 1;j <= n;j++){
if(i != j)
tmp1 = tmp1 * (k - x[j]) % mod,tmp2 = tmp2 * (x[i] - x[j]) % mod;
}
ans += tmp1 * qpow(tmp2,mod - 2) % mod;
}
printf("%lld\n",(ans % mod + mod) % mod);
return 0;
}
多项式操作
多项式求逆
定义
对于一个多项式\(A(x)\)如果存在\(B(x)\)满足\(B\)的次数不大于\(A\)并且
那么称\(B(x)\)为\(A(x)\)在\(\mod x^n\)意义下的逆元,记作\(A^{-1}(x)\)
\(\mod x^n\)是忽略次数\(\ge n\)的项。
求解方法
假设\(A(x)\)在\(\mod x^{\frac{n}{2}}\)的意义下的逆元为\(B_0(x)\),那么就有
再把上面两个式子做差,得到:
再进行化简:
左右两边同时平方:
多项式长度翻倍后上式依然成立:
左右两边同时乘以\(A(x)\)并且由于\(A(x)B(x)\equiv 1(\mod x^n)\),所以可以化简:
再经过移项就得到了最终的结果:
这个式子可以倍增或者递归来求。
代码
#include<bits/stdc++.h>
using namespace std;
const int mod = 998244353,G = 3,N = 2100000;
int n;
int a[N],b[N],c[N],rev[N];
inline int qpow(int x,int y) {
int res = 1;
while(y){
if(y & 1){
res = 1LL * res * x % mod;
}
x = 1LL * x * x % mod;
y >>= 1;
}
return res;
}
inline void NTT(int *a,int n,int x) {
for(int i = 0;i < n;i++)
if(i < rev[i])
swap(a[i],a[rev[i]]);
for(int i = 1;i < n;i <<= 1) {
int gn = qpow(G,(mod - 1) / (i << 1));
for(int j = 0;j < n;j += (i << 1)) {
int t1,t2,g = 1;
for(int k = 0;k < i;k++,g = 1LL * g * gn % mod) {
t1 = a[j + k],t2 = 1LL * g * a[j + k + i] % mod;
a[j + k] = (t1 + t2) % mod,a[j + k + i] = (t1 - t2 + mod) % mod;
}
}
}
if(x == 1)return;
int inv = qpow(n,mod - 2);
reverse(a + 1,a + n);
for(int i = 0;i < n;i++) a[i] = 1LL * a[i] * inv % mod;
}
void work(int deg,int *a,int *b) {
if(deg == 1){
b[0] = qpow(a[0],mod - 2);
return;
}
work((deg + 1) >> 1,a,b);
int len = 0,rhs = 1;
while(rhs < (deg << 1))rhs <<= 1,len++;
for(int i = 1;i < rhs;i++)rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
for(int i = 0;i < deg;i++)c[i] = a[i];
for(int i = deg;i < rhs;i++)c[i] = 0;
NTT(c,rhs,1),NTT(b,rhs,1);
for(int i = 0;i < rhs;i++)
b[i] = 1LL * (2 - 1LL * c[i] * b[i] % mod + mod) % mod * b[i] % mod;
NTT(b,rhs,-1);
for(int i = deg;i < rhs;i++)b[i] = 0;
}
int main(){
scanf("%d",&n);
for(int i = 0;i < n;i++)scanf("%d",&a[i]);
work(n,a,b);
for(int i = 0;i < n;i++)printf("%d ",b[i]);
return 0;
}
未完待续...