洛谷 [P3723] 礼物

FFT

https://www.luogu.org/problemnew/solution/P3723
重点在于构造卷积的形式：将第一个序列反转、第二个序列复制一遍接在其后，则旋转 k 位时的 Σ x_p·y_{p+k} 恰好是两者卷积的第 n-1+k 项，一次 FFT 即可同时求出所有旋转下的值。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <algorithm>
using namespace std;
// Capacity of every global array in this program; must exceed the FFT
// length and the doubled second sequence that main() writes into them.
const int MAXN = 400005;
// pi, obtained as arccos(-1).
const double PI = std::acos(-1.0);
// Reads the next decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative.
// Returns 0 if end of input is reached before any digit. The character is
// held in an int (not a char) so EOF can be told apart from real input.
// The original char-based loop never terminated at EOF.
int init() {
    int rv = 0, fh = 1;
    int c = getchar();
    while(c != EOF && (c < '0' || c > '9')) {
        if(c == '-') fh = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        rv = rv * 10 + (c - '0');
        c = getchar();
    }
    return fh * rv;
}
// Minimal complex number supporting only the operations the FFT needs.
// x is the real part, y the imaginary part.
// The closing line also defines the two global coefficient arrays a and b.
struct Complex {
    double x, y;
    Complex(double re = 0.0, double im = 0.0) : x(re), y(im) {}
    Complex operator + (const Complex &rhs) const {
        Complex sum(x + rhs.x, y + rhs.y);
        return sum;
    }
    Complex operator - (const Complex &rhs) const {
        Complex diff(x - rhs.x, y - rhs.y);
        return diff;
    }
    // (x + iy)(u + iv) = (xu - yv) + i(xv + yu)
    Complex operator * (const Complex &rhs) const {
        const double re = x * rhs.x - y * rhs.y;
        const double im = x * rhs.y + y * rhs.x;
        return Complex(re, im);
    }
}a[MAXN], b[MAXN];
// Global program state (namespace-scope, hence zero-initialised unless given a value).
int n, m;                    // sequence length; m is read from input but not used
int lim = 1, limcnt;         // FFT length (a power of two) and log2(lim)
int rev[MAXN];               // bit-reversal permutation for length lim
int num1[MAXN], num2[MAXN];  // first sequence (1-indexed), second sequence (0-indexed)
int ttt;                     // sum(num1) - sum(num2)
long long ans = 0, c;        // accumulated answer; integer offset chosen in main
// In-place iterative radix-2 FFT over a[0 .. lim-1].
//   opt =  1: forward transform.
//   opt = -1: inverse transform; both real and imaginary parts are divided by lim.
// Relies on the globals lim (a power of two) and rev[] (the bit-reversal
// permutation for lim). main() sets both up before the first call.
void fft(Complex a[], int opt) {
    // Bit-reversal permutation; the i < rev[i] test swaps each pair once.
    for(int i = 0; i < lim; i++) {
        if(i < rev[i]) swap(a[i], a[rev[i]]);
    }
    for(int mid = 1; mid < lim; mid <<= 1) {
        // Root of unity of order 2*mid; the sign of the imaginary part flips for the inverse.
        Complex wn = Complex(cos(PI / mid), opt * sin(PI / mid));
        for(int R = mid << 1, j = 0; j < lim; j += R) {
            Complex w = Complex(1.0, 0.0);
            for(int k = 0; k < mid; k++) {
                Complex x = a[j + k], y = w * a[j + mid + k];
                a[j + k] = x + y;
                a[j + mid + k] = x - y;
                w = w * wn;
            }
        }
    }
    if(opt == -1) {
        for(int i = 0; i < lim; i++) {
            a[i].x /= lim;
            a[i].y /= lim;
        }
    }
}
// Reads n, m, then sequence num1[1..n] and sequence num2[0..n-1].
// Prints the minimum of
//     sum_i (num1 - rotated num2 + c)^2
// over integer offsets c and all rotations. The sum is expanded as
//     sum num1^2 + sum num2^2 + n*c^2 + 2*c*(sum num1 - sum num2)
//     - 2 * max_k sum_p num1[p] * num2[p-1+k]   (num2 doubled)
// The c terms form a quadratic minimised by the integer nearest -ttt/n.
// The last term is obtained for every rotation k at once from a single FFT
// convolution.
int main() {
    n = init(); m = init();  // m must be consumed from the input but is not used below
    if(n <= 0) {
        // Nothing to compare; also avoids 0/0 and a negative shift count below.
        cout << 0 << endl;
        return 0;
    }
    for(int i = 1; i <= n; i++) {
        num1[i] = init();
        a[n - i].x = num1[i];             // reversed, so the correlation becomes a convolution
        ans += 1LL * num1[i] * num1[i];
        ttt += num1[i];
    }
    for(int i = 0; i < n; i++) {
        num2[i] = init();
        b[i].x = b[i + n].x = num2[i];    // written twice, so every rotation is a contiguous window
        ans += 1LL * num2[i] * num2[i];
        ttt -= num2[i];
    }
    // Round the real minimiser -ttt/n to the nearest integer (halves away from zero).
    double t = -(double)ttt / n;
    if(t > 0.0) c = (int)(t + 0.5);
    else c = (int) (t - 0.5);
    ans += n * c * c + 2 * c * ttt;
    // lim = smallest power of two > 3n, enough to hold the degree-(3n-2) product.
    // NOTE(review): requires lim <= MAXN and 2n <= MAXN; this is not checked.
    while(lim <= n * 3) lim <<= 1, limcnt++;
    for(int i = 0; i < lim; i++)
        rev[i] = (rev[i>>1]>>1) | ((i&1) << (limcnt - 1));
    fft(a, 1); fft(b, 1);
    for(int i = 0; i < lim; i++) {
        a[i] = a[i] * b[i];
    }
    fft(a, -1);
    // Coefficient n-1+k equals sum_p num1[p] * num2[p-1+k] for rotation k = 0..n-1.
    // The exact values are integers, and FFT error may fall on either side of
    // them, so round to nearest. (Truncating after adding a small epsilon loses
    // 1 whenever the error is slightly negative.)
    long long best = (long long)floor(a[n - 1].x + 0.5);
    for(int i = n; i < 2 * n - 1; i++) {
        best = max(best, (long long)floor(a[i].x + 0.5));
    }
    ans -= 2 * best;
    cout << ans << endl;
    return 0;
}
posted @ 2018-05-24 09:56  Mr_Wolfram  阅读(...)  评论(...)  编辑  收藏