多项式全家桶(自用)
多项式全家桶
FFT
namespace FFT{
/*
Floating-point FFT convolution of two integer sequences.

Both inputs are packed into one complex buffer (a in the real part, b in the
imaginary part). Squaring the spectrum leaves 2*a*b in the imaginary part, so
one forward and one inverse transform are enough for the whole product.
*/

// eps is no longer used by the rounding in convolute(); it is kept only so
// that existing references to FFT::eps continue to compile.
const double eps = 0.49,pi = std::acos(-1.0);
typedef std::complex<double> comp;
// rev: bit-reversal permutation for the current transform length.
// ans: working buffer that FFT() transforms in place.
std::vector<int> rev;std::vector<comp> ans;

/*
In-place iterative radix-2 transform of ans[0, n).
  n   : transform length; must be a power of two, and rev[0, n) must already
        hold the bit-reversal permutation for that length.
  inv : 1 for the forward transform, -1 for the inverse one. The inverse is
        not normalised; the caller divides by n.
*/
void FFT(int n,int inv){
    for(int i = 0;i < n;i++){
        if(i < rev[i]) std::swap(ans[i],ans[rev[i]]);
    }
    for(int half = 1;half < n;half <<= 1){
        // Principal root of unity for butterflies of width 2 * half.
        const comp wn(std::cos(pi / half),inv * std::sin(pi / half));
        for(int start = 0;start < n;start += (half << 1)){
            comp w(1, 0);
            for(int k = 0;k < half;k++,w *= wn){
                const comp x = ans[start + k];
                const comp y = w * ans[start + half + k];
                ans[start + k] = x + y;
                ans[start + half + k] = x - y;
            }
        }
    }
}

/*
Returns the linear convolution c of a and b:
    c[i] = sum_{j=0}^{i} a[j] * b[i - j],   0 <= i <= n + m,
where n = a.size() - 1 and m = b.size() - 1. Runs in O(len * log(len)) with
len the smallest power of two (at least 2) that is >= n + m + 1.
An empty input yields an empty result.
*/
std::vector<int> convolute(std::vector<int> a,std::vector<int> b){
    if(a.empty() || b.empty()) return std::vector<int>();
    const int n = static_cast<int>(a.size()) - 1;
    const int m = static_cast<int>(b.size()) - 1;
    // The product has n + m + 1 coefficients; the transform must hold them all,
    // otherwise the top coefficient wraps around onto c[0].
    const int need = n + m + 1;
    int bits = 1;
    while((1 << bits) < need) ++bits;
    const int len = 1 << bits;

    rev.assign(len,0);
    ans.assign(len,comp(0.0,0.0));
    for(int i = 0;i <= n;i++) ans[i].real(a[i]);
    for(int i = 0;i <= m;i++) ans[i].imag(b[i]);
    for(int i = 1;i < len;i++){
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));
    }

    FFT(len,1);
    for(int i = 0;i < len;i++) ans[i] = ans[i] * ans[i];
    FFT(len,-1);

    // imag((a + ib)^2) = 2ab, and the inverse transform still carries a factor len.
    std::vector<int> result(need);
    for(int i = 0;i < need;i++){
        result[i] = static_cast<int>(std::lround(ans[i].imag() / 2 / len));
    }
    return result;
}
};
模意义下全家桶
目前封装较差,需要进行卡常,优化待补。
namespace polymod{
/*
Polynomial arithmetic modulo an NTT-friendly prime: convolution, power-series
inverse, logarithm and exponential.

Call init(p) before using polyinv / polyln / polyexp. Coefficients are expected
to be residues in [0, mod). Every product is widened to long long before the
reduction, so the code does not depend on how wide `int` is.
*/

// ---- shared state ----------------------------------------------------------
std::vector<int> rev;            // bit-reversal permutation read by NTT()
int mod, g, gi;                  // modulus, primitive root, inverse of the root
std::vector<int> pwwwg, pwwwig;  // pwwwg[k] = g^k and pwwwig[k] = gi^k for 0 <= k <= 1e6
std::vector<int> invvv;          // invvv[k] = k^{-1} mod p for 1 <= k <= 1e6, invvv[0] = 0
int len, L;                      // transform length currently in use and its log2
typedef std::vector<int> vi;
vi invb, inva, invc;             // scratch buffers of polyinv / getinv
vi lna, lnb, A, B;               // scratch buffers of polyln / getln
vi expa, expb, lnnb, a1;         // scratch buffers of polyexp / getexp

namespace detail{
// Largest exponent / argument covered by the lookup tables built in init().
const int kTableLimit = 1000000;

// x * y mod p, computed in 64 bits.
inline int mul(int x,int y){
    return static_cast<int>(static_cast<long long>(x) * y % mod);
}

// Smallest power of two that is >= 2 * n.
inline int padded_length(int n){
    int size = 1;
    while(size < (n << 1)) size <<= 1;
    return size;
}

// Fills rev[0, size) with the bit-reversal permutation on `bits` bits.
// rev[0] is written explicitly so that stale contents of rev cannot leak in.
inline void build_rev(int size,int bits){
    if(size <= 0) return;
    if(static_cast<int>(rev.size()) < size) rev.resize(size);
    rev[0] = 0;
    for(int i = 1;i < size;i++){
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));
    }
}

// Sets len to the smallest power of two >= 2 * n, L to its log2, and rebuilds rev.
inline void prepare(int n){
    len = 1;L = 0;
    while(len < (n << 1)) len <<= 1,L++;
    build_rev(len,L);
}
}

// ---- modular power ---------------------------------------------------------
/*
Returns x^a mod p (1 for a <= 0). Uses the tables from init() when the query is
a power of g or gi with a small exponent, or an inverse (a == mod - 2) of a
small x; otherwise falls back to square-and-multiply.
*/
int qpow(int x,int a){
    if(a >= 0){
        const std::size_t e = static_cast<std::size_t>(a);
        if(x == g && e < pwwwg.size()) return pwwwg[e];
        if(x == gi && e < pwwwig.size()) return pwwwig[e];
    }
    if(a == mod - 2 && x >= 0 && static_cast<std::size_t>(x) < invvv.size()){
        return invvv[x];
    }
    int res = 1;
    while(a > 0){
        if(a & 1) res = detail::mul(res,x);
        x = detail::mul(x,x);
        a >>= 1;
    }
    return res;
}

/*
Selects the modulus and rebuilds every lookup table for it.
Only 998244353 and 1004535809 are recognised (primitive root 3); for any other
modulus a message is written to stderr and g keeps its previous value.
*/
void init(int md){
    mod = md;
    if(mod == 998244353 || mod == 1004535809){g = 3;}
    else {std::cerr << "No Root !";}

    // Drop tables from a previous modulus before qpow() can consult them.
    pwwwg.clear();pwwwig.clear();invvv.clear();
    gi = qpow(g,mod - 2);

    pwwwg.reserve(detail::kTableLimit + 1);
    pwwwig.reserve(detail::kTableLimit + 1);
    invvv.reserve(detail::kTableLimit + 1);

    pwwwg.push_back(1);pwwwig.push_back(1);
    for(int i = 1;i <= detail::kTableLimit;i++){
        pwwwg.push_back(detail::mul(pwwwg.back(),g));
        pwwwig.push_back(detail::mul(pwwwig.back(),gi));
    }
    // Linear-time inverses: i^{-1} = -(p / i) * (p mod i)^{-1}.
    invvv.push_back(0);invvv.push_back(1);
    for(int i = 2;i <= detail::kTableLimit;i++){
        invvv.push_back(detail::mul(mod - mod / i,invvv[mod % i]));
    }
}

// ---- NTT -------------------------------------------------------------------
/*
In-place number-theoretic transform of ans[0, n).
  n   : power of two; rev[0, n) must hold the matching bit-reversal permutation.
  inv : 1 for the forward transform, anything else for the inverse one. The
        inverse is not normalised; the caller multiplies by n^{-1}.
*/
void NTT(std::vector<int> &ans,int n,int inv){
    for(int i = 0;i < n;i++){
        if(i < rev[i]) std::swap(ans[i],ans[rev[i]]);
    }
    for(int half = 1;half < n;half <<= 1){
        const int wn = qpow((inv == 1 ? g : gi),(mod - 1) / (half << 1));
        for(int start = 0;start < n;start += (half << 1)){
            int w = 1;
            for(int k = 0;k < half;k++,w = detail::mul(w,wn)){
                const int x = ans[start + k];
                const int y = detail::mul(w,ans[start + half + k]);
                ans[start + k] = static_cast<int>((static_cast<long long>(x) + y) % mod);
                ans[start + half + k] = static_cast<int>((static_cast<long long>(x) - y + mod) % mod);
            }
        }
    }
}

/*
Returns the linear convolution c of a and b modulo p:
    c[i] = sum_{j=0}^{i} a[j] * b[i - j]  (mod p),   0 <= i <= n + m,
where n = a.size() - 1 and m = b.size() - 1. Runs in O(size * log(size)) with
size the smallest power of two (at least 2) that is >= n + m + 1.
An empty input yields an empty result.
*/
std::vector<int> convolute(std::vector<int> a,std::vector<int> b){
    if(a.empty() || b.empty()) return std::vector<int>();
    gi = qpow(g,mod - 2);
    const int n = static_cast<int>(a.size()) - 1;
    const int m = static_cast<int>(b.size()) - 1;
    // The product has n + m + 1 coefficients; a shorter transform would wrap
    // the top coefficient around onto c[0].
    const int need = n + m + 1;
    int bits = 1;
    while((1 << bits) < need) ++bits;
    const int size = 1 << bits;

    a.resize(size);b.resize(size);
    detail::build_rev(size,bits);
    NTT(a,size,1);NTT(b,size,1);
    for(int i = 0;i < size;i++) a[i] = detail::mul(a[i],b[i]);
    NTT(a,size,-1);

    const int inv = qpow(size,mod - 2);
    std::vector<int> result(need);
    for(int i = 0;i < need;i++) result[i] = detail::mul(a[i],inv);
    return result;
}

// ---- polyinv ---------------------------------------------------------------
/*
Newton iteration on the global buffers: on return invb[0, n) holds the inverse
of inva modulo x^n and invb[n, len) is zero. inva, invb and invc must hold at
least padded_length(n) entries, with invb zero-initialised; polyinv() sets
this up. inva[0] must be invertible.
*/
void getinv(int n){
    if(n <= 0) return;
    if(n == 1){invb[0] = qpow(inva[0],mod - 2);return;}
    getinv((n + 1) >> 1);
    detail::prepare(n);
    for(int i = 0;i < n;i++) invc[i] = inva[i];
    for(int i = n;i < len;i++) invc[i] = 0;
    NTT(invc,len,1);NTT(invb,len,1);
    // b <- b * (2 - a * b), evaluated pointwise.
    for(int i = 0;i < len;i++){
        const int ab = detail::mul(invb[i],invc[i]);
        invb[i] = detail::mul((2 - ab + mod) % mod,invb[i]);
    }
    NTT(invb,len,-1);
    const int inv = qpow(len,mod - 2);
    for(int i = 0;i < n;i++) invb[i] = detail::mul(invb[i],inv);
    for(int i = n;i < len;i++) invb[i] = 0;
}

/*
Returns the first n coefficients of 1 / arr as a power series.
n defaults to arr.size(); arr is zero-extended or truncated to n terms.
arr[0] must be invertible modulo p. Returns an empty vector when n <= 0.
*/
vi polyinv(vi arr,int n = -1){
    if(n == -1) n = static_cast<int>(arr.size());
    if(n <= 0) return vi();
    const int cap = detail::padded_length(n);
    inva.swap(arr);inva.resize(cap);
    invb.assign(cap,0);invc.assign(cap,0);
    getinv(n);
    invb.resize(n);
    return invb;
}

// ---- polyln ----------------------------------------------------------------
// B[0, len) <- derivative of A[0, len). Safe when A and B are the same vector.
void Derive(const vi &A,vi &B,int len){
    if(len <= 0) return;
    for(int i = 1;i < len;i++) B[i - 1] = detail::mul(i,A[i]);
    B[len - 1] = 0;
}

// B[0, len) <- antiderivative of A with constant term 0.
// Runs from the top down so that it is safe when A and B are the same vector.
void Integrate(const vi &A,vi &B,int len){
    for(int i = len - 1;i >= 1;i--) B[i] = detail::mul(A[i - 1],qpow(i,mod - 2));
    if(!B.empty()) B[0] = 0;
}

/*
g[0, n) <- ln(f) mod x^n, computed as the integral of f' / f.
f must hold at least n coefficients. Uses the global scratch buffers A and B.
*/
void getln(vi &f,vi &g,int n){
    if(n <= 0) return;
    B = polyinv(f,n);            // overwrites len / L / rev, so prepare() comes after
    detail::prepare(n);
    if(static_cast<int>(A.size()) < len) A.resize(len);
    Derive(f,A,n);
    for(int i = n;i < len;i++) A[i] = 0;
    B.resize(len);
    NTT(A,len,1);NTT(B,len,1);
    for(int i = 0;i < len;i++) A[i] = detail::mul(A[i],B[i]);
    NTT(A,len,-1);
    const int inv = qpow(len,mod - 2);
    for(int i = 0;i < n;i++) A[i] = detail::mul(A[i],inv);
    if(static_cast<int>(g.size()) < n) g.resize(n);
    Integrate(A,g,n);
}

/*
Returns the first n coefficients of ln(arr).
n defaults to arr.size(). The series logarithm is only defined for
arr[0] == 1; this is not checked. Returns an empty vector when n <= 0.
*/
vi polyln(vi arr,int n = -1){
    if(n == -1) n = static_cast<int>(arr.size());
    if(n <= 0) return vi();
    lna.swap(arr);lna.resize(n);
    lnb.assign(n,0);
    A.clear();B.clear();
    getln(lna,lnb,n);
    lnb.resize(n);
    return lnb;
}

// ---- polyexp ---------------------------------------------------------------
/*
Newton iteration b <- b * (1 + a - ln b): on return b[0, n) holds exp(a) mod x^n
and b[n, len) is zero. a must hold at least n coefficients; b must be zero
beyond the coefficients computed so far (polyexp() passes an all-zero buffer).
*/
void getexp(const vi &a,vi &b,int n){
    if(n <= 0) return;
    if(n == 1){
        if(b.empty()) b.resize(1);
        b[0] = 1;
        return;
    }
    getexp(a,b,(n + 1) >> 1);
    lnnb = polyln(b,n);
    detail::prepare(n);
    a1.assign(len,0);
    for(int i = 0;i < n;i++){
        a1[i] = static_cast<int>(((static_cast<long long>(a[i]) - lnnb[i]) % mod + mod) % mod);
    }
    a1[0] = (a1[0] + 1) % mod;
    if(static_cast<int>(b.size()) < len) b.resize(len);
    NTT(b,len,1);NTT(a1,len,1);
    for(int i = 0;i < len;i++) b[i] = detail::mul(b[i],a1[i]);
    NTT(b,len,-1);
    const int inv = qpow(len,mod - 2);
    for(int i = 0;i < n;i++) b[i] = detail::mul(b[i],inv);
    for(int i = n;i < len;i++) b[i] = 0;
}

/*
Returns the first n coefficients of exp(arr).
n defaults to arr.size(). The series exponential is only defined for
arr[0] == 0; this is not checked. Returns an empty vector when n <= 0.
*/
vi polyexp(vi arr,int n = -1){
    if(n == -1) n = static_cast<int>(arr.size());
    if(n <= 0) return vi();
    expa.swap(arr);expa.resize(n);
    expb.assign(detail::padded_length(n),0);
    a1.clear();lnnb.clear();
    getexp(expa,expb,n);
    expb.resize(n);
    return expb;
}
};

浙公网安备 33010602011771号