洛谷P5205 【模板】多项式开根

https://www.luogu.org/problemnew/show/P5205

按道理说,多项式开根可以有多个解(根据常数项不同有不同的解)。此题只需要输出常数项为1的解(题面漏了)

首先,可以直接多项式快速幂做(2对998244353的逆元)次幂(直接做只能在输入常数项为1时)(我不是很懂为什么能起效,不过的确能AC)

版本1:基于版本1

  1 #prag\
  2 ma GCC optimize(2)
  3 #include<cstdio>
  4 #include<algorithm>
  5 #include<cstring>
  6 #include<vector>
  7 #include<cmath>
  8 using namespace std;
  9 #define fi first
 10 #define se second
 11 #define mp make_pair
 12 #define pb push_back
 13 typedef long long ll;
 14 typedef unsigned long long ull;
 15 const int md=998244353;
 16 const int N=262144;
 17 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md))
 18 inline int del(int a,int b)
 19 {
 20     a-=b;
 21     return a<0?a+md:a;
 22 }
 23 int rev[N];
 24 void init(int len)
 25 {
 26     int bit=0,i;
 27     while((1<<(bit+1))<=len)    ++bit;
 28     for(i=1;i<len;++i)
 29         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
 30 }
 31 ull poww(ull a,ull b)
 32 {
 33     ull ans=1;
 34     for(;b;b>>=1,a=a*a%md)
 35         if(b&1)
 36             ans=ans*a%md;
 37     return ans;
 38 }
 39 int inv[300011];
 40 void dft(int *a,int len,int idx)//要求len为2的幂
 41 {
 42     int i,j,k,t1,t2;ull wn,wnk;
 43     for(i=0;i<len;++i)
 44         if(i<rev[i])
 45             swap(a[i],a[rev[i]]);
 46     for(i=1;i<len;i<<=1)
 47     {
 48         wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
 49         for(j=0;j<len;j+=(i<<1))
 50         {
 51             wnk=1;
 52             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
 53             {
 54                 t1=a[k];t2=a[k+i]*wnk%md;
 55                 a[k]+=t2;
 56                 (a[k]>=md)&&(a[k]-=md);
 57                 a[k+i]=t1-t2;
 58                 (a[k+i]<0)&&(a[k+i]+=md);
 59             }
 60         }
 61     }
 62     if(idx==-1)
 63     {
 64         ull ilen=inv[len];
 65         for(i=0;i<len;++i)
 66             a[i]=a[i]*ilen%md;
 67     }
 68 }
 69 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2len(需要足够长用于临时存放元素);要求len是2的幂
 70 {
 71     static int t1[N],t2[N];
 72     g[0]=poww(f[0],md-2);
 73     for(int i=2,j;i<=len;i<<=1)
 74     {
 75         memcpy(t1,f,sizeof(int)*i);
 76         memcpy(t2,g,sizeof(int)*(i>>1));
 77         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
 78         init(i);
 79         dft(t1,i,1);dft(t2,i,1);
 80         for(j=0;j<i;++j)
 81             t1[j]=ull(t1[j])*t2[j]%md;
 82         dft(t1,i,-1);
 83         for(j=0;j<(i>>1);++j)
 84             t1[j]=t1[j+(i>>1)];
 85         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
 86         dft(t1,i,1);
 87         for(j=0;j<i;++j)
 88             t1[j]=ull(t1[j])*t2[j]%md;
 89         dft(t1,i,-1);
 90         for(j=i>>1;j<i;++j)
 91             g[j]=md-t1[j-(i>>1)];
 92     }
 93 }
 94 inline void p_de(int *f,int len)//derivative求导;f=f'
 95 {
 96     for(int i=0;i<len-1;++i)
 97         f[i]=ull(i+1)*f[i+1]%md;
 98     f[len-1]=0;
 99 }
100 inline void p_in(int *f,int len)//integral积分;f=?f
101 {
102     for(int i=len-1;i>=1;--i)
103         f[i]=ull(f[i-1])*inv[i]%md;
104     f[0]=0;
105 }
106 void p_ln(int *f,int len)//要求len为2的幂,f[0]=1
107 {
108     static int t3[N];
109     p_inv(f,t3,len);p_de(f,len);
110     init(len<<1);
111     dft(f,len<<1,1);dft(t3,len<<1,1);
112     for(int i=0;i<(len<<1);++i)
113         f[i]=ull(f[i])*t3[i]%md;
114     dft(f,len<<1,-1);p_in(f,len);
115 }
116 void p_exp(int *f,int *g,int len)//要求len为2的幂,f[0]=0
117 {
118     static int t1[N],t2[N];
119     g[0]=1;
120     for(int i=2,j;i<=len;i<<=1)
121     {
122         memcpy(t1,g,sizeof(int)*(i>>1));
123         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
124         p_ln(t1,i);
125         for(j=0;j<(i>>1);++j)
126             t1[j]=del(f[j+(i>>1)],t1[j+(i>>1)]);
127         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
128         init(i);
129         dft(t1,i,1);
130         memcpy(t2,g,sizeof(int)*(i>>1));
131         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
132         dft(t2,i,1);
133         for(j=0;j<i;++j)
134             t1[j]=ull(t1[j])*t2[j]%md;
135         dft(t1,i,-1);
136         for(j=i>>1;j<i;++j)
137             g[j]=t1[j-(i>>1)];
138     }
139 }
140 inline void p_pow_1(int *f,int *g,int len,int b)//要求len为2的幂,常数项为1
141 {
142     p_ln(f,len);
143     for(int i=0;i<len;++i)
144         f[i]=ull(f[i])*b%md;
145     p_exp(f,g,len);
146 }
147 void p_pow(int *f,int *g,int len,int b)//g=f^b;要求len为2的幂
148 {
149     int i;ll p=-1;
150     for(i=0;i<len;++i)
151         if(f[i])
152         {
153             p=i;
154             break;
155         }
156     if(p==-1)    return;
157     for(i=0;i<len-p;++i)
158         f[i]=f[i+p];
159     memset(f+len-p,0,sizeof(int)*p);
160     int t=f[0],t1=poww(t,md-2),t2=poww(t,b);
161     for(i=0;i<len;++i)
162         f[i]=ull(f[i])*t1%md;
163     p_pow_1(f,g,len,b);
164     for(i=0;i<len;++i)
165         g[i]=ull(g[i])*t2%md;
166     p*=b;
167     for(i=len-1;i>=p;--i)
168         g[i]=g[i-p];
169     memset(g,0,sizeof(int)*min(ll(len),p));
170 }
171 int a[N],b[N];
172 int n,n1;
173 int main()
174 {
175     int i,t;
176     inv[1]=1;
177     for(i=2;i<=300000;++i)
178         inv[i]=ull(md-md/i)*inv[md%i]%md;
179     scanf("%d",&n);n1=n;
180     for(i=0;i<n;++i)
181         scanf("%d",a+i);
182     for(t=1;t<n;t<<=1);
183     n=t;
184     p_pow(a,b,n,499122177);
185     for(i=0;i<n1;++i)
186         printf("%d ",b[i]);
187     return 0;
188 }
View Code

也可以直接牛顿迭代做。设$g(f(x))=f(x)^2-A(x)$

$f(x)=f_0(x)-\frac{f_0(x)^2-A(x)}{2f_0(x)}=\frac{A(x)}{2f_0(x)}+\frac{f_0(x)}{2}$

版本2:基于版本2

  1 #prag\
  2 ma GCC optimize(2)
  3 #include<cstdio>
  4 #include<algorithm>
  5 #include<cstring>
  6 #include<vector>
  7 #include<cmath>
  8 using namespace std;
  9 #define fi first
 10 #define se second
 11 #define mp make_pair
 12 #define pb push_back
 13 typedef long long ll;
 14 typedef unsigned long long ull;
 15 const int md=998244353;
 16 const int N=262144;
 17 #define addto(a,b) ((a)+=(b),((a)>=md)&&((a)-=md))
 18 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md))
 19 inline int del(int a,int b)
 20 {
 21     a-=b;
 22     return a<0?a+md:a;
 23 }
 24 int rev[N];
 25 void init(int len)
 26 {
 27     int bit=0,i;
 28     while((1<<(bit+1))<=len)    ++bit;
 29     for(i=1;i<len;++i)
 30         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
 31 }
 32 ull poww(ull a,ull b)
 33 {
 34     ull ans=1;
 35     for(;b;b>>=1,a=a*a%md)
 36         if(b&1)
 37             ans=ans*a%md;
 38     return ans;
 39 }
 40 int inv[300011];
 41 void dft(int *a,int len,int idx)//要求len为2的幂
 42 {
 43     int i,j,k,t1,t2;ull wn,wnk;
 44     for(i=0;i<len;++i)
 45         if(i<rev[i])
 46             swap(a[i],a[rev[i]]);
 47     for(i=1;i<len;i<<=1)
 48     {
 49         wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
 50         for(j=0;j<len;j+=(i<<1))
 51         {
 52             wnk=1;
 53             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
 54             {
 55                 t1=a[k];t2=a[k+i]*wnk%md;
 56                 a[k]+=t2;
 57                 (a[k]>=md)&&(a[k]-=md);
 58                 a[k+i]=t1-t2;
 59                 (a[k+i]<0)&&(a[k+i]+=md);
 60             }
 61         }
 62     }
 63     if(idx==-1)
 64     {
 65         ull ilen=inv[len];
 66         for(i=0;i<len;++i)
 67             a[i]=a[i]*ilen%md;
 68     }
 69 }
 70 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2len(需要足够长用于临时存放元素);要求len是2的幂
 71 {
 72     static int t1[N],t2[N];
 73     g[0]=poww(f[0],md-2);
 74     for(int i=2,j;i<=len;i<<=1)
 75     {
 76         memcpy(t1,f,sizeof(int)*i);
 77         memcpy(t2,g,sizeof(int)*(i>>1));
 78         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
 79         init(i);
 80         dft(t1,i,1);dft(t2,i,1);
 81         for(j=0;j<i;++j)
 82             t1[j]=ull(t1[j])*t2[j]%md;
 83         dft(t1,i,-1);
 84         for(j=0;j<(i>>1);++j)
 85             t1[j]=t1[j+(i>>1)];
 86         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
 87         dft(t1,i,1);
 88         for(j=0;j<i;++j)
 89             t1[j]=ull(t1[j])*t2[j]%md;
 90         dft(t1,i,-1);
 91         for(j=i>>1;j<i;++j)
 92             g[j]=md-t1[j-(i>>1)];
 93     }
 94 }
 95 inline void p_de(int *f,int len)//derivative求导;f=f'
 96 {
 97     for(int i=0;i<len-1;++i)
 98         f[i]=ull(i+1)*f[i+1]%md;
 99     f[len-1]=0;
100 }
101 inline void p_in(int *f,int len)//integral积分;f=?f
102 {
103     for(int i=len-1;i>=1;--i)
104         f[i]=ull(f[i-1])*inv[i]%md;
105     f[0]=0;
106 }
107 void p_ln(int *f,int len)//要求len为2的幂,f[0]=1
108 {
109     static int t3[N];
110     p_inv(f,t3,len);p_de(f,len);
111     init(len<<1);
112     dft(f,len<<1,1);dft(t3,len<<1,1);
113     for(int i=0;i<(len<<1);++i)
114         f[i]=ull(f[i])*t3[i]%md;
115     dft(f,len<<1,-1);p_in(f,len);
116 }
117 void p_exp(int *f,int *g,int len)//要求len为2的幂,f[0]=0
118 {
119     static int t1[N],t2[N];
120     g[0]=1;
121     for(int i=2,j;i<=len;i<<=1)
122     {
123         memcpy(t1,g,sizeof(int)*(i>>1));
124         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
125         p_ln(t1,i);
126         for(j=0;j<(i>>1);++j)
127             t1[j]=del(f[j+(i>>1)],t1[j+(i>>1)]);
128         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
129         init(i);
130         dft(t1,i,1);
131         memcpy(t2,g,sizeof(int)*(i>>1));
132         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
133         dft(t2,i,1);
134         for(j=0;j<i;++j)
135             t1[j]=ull(t1[j])*t2[j]%md;
136         dft(t1,i,-1);
137         for(j=i>>1;j<i;++j)
138             g[j]=t1[j-(i>>1)];
139     }
140 }
141 void p_sqrt(int *f,int *g,int len)//g=sqrt(f);要求len为2的幂,f[0]=1
142 {
143     static int t1[N],t2[N];
144     g[0]=1;
145     for(int i=2,j;i<=len;i<<=1)
146     {
147         memcpy(t1,g,sizeof(int)*(i>>1));
148         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
149         for(j=0;j<i;++j)
150             addto(t1[j],t1[j]);
151         p_inv(t1,t2,i);
152         memset(t2+i,0,sizeof(int)*i);
153         memcpy(t1,f,sizeof(int)*i);
154         memset(t1+i,0,sizeof(int)*i);
155         init(i<<1);
156         dft(t1,i<<1,1);dft(t2,i<<1,1);
157         for(j=0;j<(i<<1);++j)
158             t1[j]=ull(t1[j])*t2[j]%md;
159         dft(t1,i<<1,-1);
160         for(j=0;j<(i>>1);++j)
161             g[j]=(ull(g[j])*499122177+t1[j])%md;
162         memcpy(g+(i>>1),t1+(i>>1),sizeof(int)*(i>>1));
163     }
164 }
165 int a[N],b[N];
166 int n,n1;
167 int main()
168 {
169     int i,t;
170     inv[1]=1;
171     for(i=2;i<=300000;++i)
172         inv[i]=ull(md-md/i)*inv[md%i]%md;
173     scanf("%d",&n);n1=n;
174     for(i=0;i<n;++i)
175         scanf("%d",a+i);
176     for(t=1;t<n;t<<=1);
177     n=t;
178     p_sqrt(a,b,n);
179     for(i=0;i<n1;++i)
180         printf("%d ",b[i]);
181     return 0;
182 }
View Code

 

posted @ 2019-03-23 13:36  hehe_54321  阅读(273)  评论(0编辑  收藏  举报
AmazingCounters.com