BZOJ4827 [Hnoi2017]礼物 多项式 FFT

原文链接http://www.cnblogs.com/zhouzhendong/p/8823962.html

题目传送门 - BZOJ4827

题意

  有两个长为$n$的序列$x$和$y$,序列$x,y$的第$i$项分别是$x_i,y_i$。

  选择一个序列$A$,现在你可以对它进行如下两种操作:

  $1.$ 得到一个和$A$循环同构的序列$A'$。

  $2.$ 给所有的$A'_i$都加上$c(c\in N^+)$,得到序列$A''$。

  你进行上面两个操作之后,得到的序列分别为$x'',y''$(注意$x,y$两个序列中至少有一个序列没有发生任何变化)

  序列$x''$和$y''$的差异值为

$$\sum_{i=1}^{n}(x''_i-y''_i)^2$$

  问差异值最小为多少。

题解

  我们可以选择任何一个序列进行那两个操作就是相当于对$x$进行操作$1$和下面这个操作:

  $2'.$ 给所有的$x'_i$都加上$c(c\in Z)$,得到序列$x''$。

  于是:

  $$\sum_{i=1}^{n}(x''_i-y_i)^2$$

  $$=\sum_{i=1}^{n}(x'_i-y_i+c)^2$$

  (假装后面的没有'的就当有吧,公式不知道为啥烂掉了)

  $$=\sum_{i=1}^{n}x_i^2+y_i^2+c^2-2x_iy_i+2(x_i-y_i)c$$

  $$=\sum_{i=1}^{n}x_i^2+y_i^2+c^2+(\sum_{i=1}^{n}(x_i-y_i))c-2x_iy_i$$

  发现这是个关于$c$的二次函数再加上一坨$-2x'_iy_i$。

  对于$c$只需要运用一下初中的数学知识即可。

  具体地:

    设$t=\sum_{i=1}^{n}x_i-y_i$.

    则与$c$有关的式子是$nc^2+2tc$。

    这个时候其实$-100$到$100$暴搜应该也可以的。

    然后我们运用初中的数学配个方就可以快速算出$c$的取值,但是注意$c$为整数,所以我们就去周围几个数判几下就可以了。

  对于后面的那个,只需要倍长$y$,翻转$x$,然后套路$FFT$即可。

  具体地:

  为了方便,下面把$x_i$和$y_i$下标看做从$0$开始。

  首先化环为链,我们把$y$倍长。

  然后构造多项式:

  $$h_i=\sum_{j=0}^{n-1}x_jy_{i+j-1}$$

  看起来可以用$FFT$来优化。

  我们只需要翻转$x$数组,得到:

  $$h_i=\sum_{j=0}^{n-1}x_{n-j-1}y_{i+j-1}$$

  让$h_i$整体右移$n-1$位,再重新整理上式得:

  $$h_i=\sum_{j=0}^{i}x_jy_{i-j}$$

  显然这个可以$FFT$。

  其中只有$h_{n-1}\dots h_{2n-2}$是有用的。

  求出来之后,只要求下有效$h_i$的最大值$max$即可。然后让答案减掉$-2max$

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=1<<18;
double PI=acos(-1.0);
int n,m,c,x[N],y[N],t=0;
int s,L,R[N];
LL ans=0,xymax=0;
struct C{
	double r,i;
	C(){}
	C(double a,double b){r=a,i=b;}
	C operator + (C x){return C(r+x.r,i+x.i);}
	C operator - (C x){return C(r-x.r,i-x.i);}
	C operator * (C x){return C(r*x.r-i*x.i,r*x.i+i*x.r);}
}w[N],A[N],B[N];
LL calc(LL c){
	return c*c*n+c*t*2;
}
LL getc_val(int t,int n){
	LL c1=-t/n,c2=c1+1,c3=c1-1;
	return min(min(calc(c1),calc(c2)),calc(c3));
}
void FFT(C a[],int n){
	for (int i=0;i<n;i++)
		if (i<R[i])
			swap(a[i],a[R[i]]);
	for (int t=n>>1,d=1;d<n;d<<=1,t>>=1)
		for (int i=0;i<n;i+=(d<<1))
			for (int j=0;j<d;j++){
				C tmp=w[t*j]*a[i+j+d];
				a[i+j+d]=a[i+j]-tmp;
				a[i+j]=a[i+j]+tmp;
			}
}
int main(){
	scanf("%d%d",&n,&m);
	for (int i=0;i<n;i++)
		scanf("%d",&x[i]);
	for (int i=0;i<n;i++)
		scanf("%d",&y[i]),t+=x[i]-y[i],ans+=x[i]*x[i]+y[i]*y[i];
	ans+=getc_val(t,n);
	for (s=1,L=0;s<n*3;s<<=1,L++);
	for (int i=0;i<s;i++){
		R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
		w[i]=C(cos(2*i*PI/s),sin(2*i*PI/s));
		A[i]=B[i]=C(0,0);
	}
	for (int i=0;i<n;i++)
		A[i]=C(x[n-i-1],0),B[i]=B[i+n]=C(y[i],0);
	FFT(A,s),FFT(B,s);
	for (int i=0;i<s;i++)
		A[i]=A[i]*B[i],w[i].i*=-1.0;
	FFT(A,s);
	for (int i=n-1;i<2*n;i++)
		xymax=max(xymax,(LL)(A[i].r/s+0.5));
	ans-=xymax*2;
	printf("%lld",ans);
	return 0;
}

  

posted @ 2018-04-13 21:07  zzd233  阅读(386)  评论(0编辑  收藏  举报