Gym 102920 H. Needle(FFT)
题目链接
题意:
给定三个整数数组 \(a\)、\(b\)、\(c\),求满足 \(a_i+c_j=2b_k\) 的三元组 \((i,j,k)\) 的个数。
思路:
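将数组 \(a\)、\(c\) 分别看成多项式:若值 \(v\) 在数组中出现,则 \(x^v\) 的系数为 \(v\) 的出现次数(代码中先给每个值加上偏移量 \(d\),保证指数非负)。两多项式相乘后,\(x^s\) 的系数即为满足 \(a_i+c_j=s\) 的 \((i,j)\) 对数,因此对每个 \(b_k\) 累加指数 \(2b_k\)(加偏移后为 \(2(b_k+d)\))处的系数即为答案,多项式乘法用 FFT 完成。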
code:
#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
#include <deque>
#include <cmath>
#include <ctime>
#include <map>
#include <set>
// #include <unordered_map>
#define fi first
#define se second
#define pb push_back
// #define endl "\n"
#define debug(x) cout << #x << ":" << x << endl;
#define bug cout << "********" << endl;
#define all(x) x.begin(), x.end()
#define lowbit(x) x & -x
#define fin(x) freopen(x, "r", stdin)
#define fout(x) freopen(x, "w", stdout)
#define ull unsigned long long
#define ll long long
const double eps = 1e-15;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const double pi = acos(-1.0);
const int mod = 998244353;
const int maxn = 1e6 + 10;
using namespace std;
// Minimal complex number for the FFT: x is the real part, y the imaginary part.
struct Complex{
    double x, y;

    // Both parts default to zero, so Complex() is 0 and Complex(r) is the real number r.
    Complex(double _x = 0.0, double _y = 0.0) : x(_x), y(_y) {}

    // Component-wise sum.
    Complex operator +(const Complex &rhs) const {
        const double re = x + rhs.x;
        const double im = y + rhs.y;
        return Complex(re, im);
    }

    // Component-wise difference.
    Complex operator -(const Complex &rhs) const {
        const double re = x - rhs.x;
        const double im = y - rhs.y;
        return Complex(re, im);
    }

    // Complex product: (x + yi)(rhs.x + rhs.y i).
    Complex operator *(const Complex &rhs) const {
        const double re = x * rhs.x - y * rhs.y;
        const double im = x * rhs.y + y * rhs.x;
        return Complex(re, im);
    }
};
// Scratch table filled by change(): rev[i] is the index that position i is paired with.
// It has static storage, so rev[0] starts at 0 and stays 0 (it is only ever set to rev[0] >> 1).
int rev[maxn];

// Permutes y[0..len) in place so that element i trades places with element rev[i].
// For a power-of-two len this is the bit-reversal ordering used before the butterflies.
// The table is rebuilt on every call from the recurrence
//   rev[i] = rev[i / 2] / 2, with the top bit (len / 2) set when i is odd.
// NOTE(review): len is not checked against the size of rev; callers must keep len <= maxn.
void change(Complex y[], int len){
    const int top = len >> 1;
    for(int idx = 0; idx < len; ++idx){
        const int shifted = rev[idx >> 1] >> 1;
        rev[idx] = (idx & 1) ? (shifted | top) : shifted;
    }
    for(int idx = 0; idx < len; ++idx){
        const int mate = rev[idx];
        // Swap each pair once only: act when the partner index is the larger one.
        if(mate > idx){
            std::swap(y[idx], y[mate]);
        }
    }
}
// In-place iterative radix-2 FFT over y[0..len).
//   y   - data buffer; overwritten with the transform.
//   len - number of points. The butterflies pair k with k + h/2 for h = 2, 4, ..., len,
//         so len is expected to be a power of two.
//   on  - 1 for the forward transform, -1 for the inverse transform.
// For on == -1 the result is divided by len (both real and imaginary parts), so that
// a forward transform followed by an inverse one restores the input.
void fft(Complex y[], int len, int on){
    change(y, len);
    // Start at block size 2: a block of size 1 contains no butterflies.
    for(int h = 2; h <= len; h <<= 1){
        const int half = h >> 1;
        // Step between consecutive twiddle factors of this stage; the sign of `on`
        // selects the rotation direction.
        const Complex wn(cos(2 * pi / h), sin(on * 2 * pi / h));
        for(int i = 0; i < len; i += h){
            Complex w(1, 0);
            for(int k = i; k < i + half; k ++){
                const Complex u = y[k];
                const Complex t = w * y[k + half];
                y[k] = u + t;
                y[k + half] = u - t;
                w = w * wn;
            }
        }
    }
    if(on == -1){
        // Normalise the whole complex value, not just the real part.
        for(int i = 0; i < len; i ++){
            y[i].x /= len;
            y[i].y /= len;
        }
    }
}
// x1 / x2: occurrence-count polynomials of the first and third input arrays
// (index = shifted value); x1 is reused to hold their product.
Complex x1[maxn], x2[maxn];
// sum[s]: rounded coefficient of x^s in the product; b[i]: shifted values of the second array.
int sum[maxn], b[maxn];
int n, m, l;

// Reads three integer arrays from stdin (each preceded by its length) and prints the
// number of triples (i, j, k) with first[i] + third[j] == 2 * second[k].
// Returns 0.
int main(){
    // Every value is shifted by d so it can be used as a polynomial exponent / array index.
    // NOTE(review): assumes all inputs are >= -d and small enough that shifted values and
    // the transform length fit in maxn — confirm against the problem limits.
    const int d = 30010;
    int maxx = 0, a;

    scanf("%d", &n);
    for(int i = 1; i <= n; i ++){
        scanf("%d", &a);
        a += d;
        x1[a].x ++;
        maxx = max(maxx, a);
    }

    scanf("%d", &m);
    for(int i = 1; i <= m; i ++){
        scanf("%d", &b[i]);
        b[i] += d;
    }

    scanf("%d", &l);
    for(int i = 1; i <= l; i ++){
        scanf("%d", &a);
        a += d;
        x2[a].x ++;
        maxx = max(maxx, a);
    }

    // The product's highest exponent is 2 * maxx, so the transform needs strictly more
    // than 2 * maxx points. With `len < 2 * maxx` a power-of-two 2 * maxx (e.g. largest
    // input 2758 -> shifted 32768) made that coefficient wrap around to index 0 and the
    // matching triples were lost.
    int len = 1;
    while(len <= 2 * maxx)len <<= 1;

    fft(x1, len, 1);
    fft(x2, len, 1);
    for(int i = 0; i < len; i ++)x1[i] = x1[i] * x2[i];
    fft(x1, len, -1);

    // Coefficients are non-negative integers up to floating-point error; round to nearest.
    for(int i = 0; i < len; i ++)sum[i] = static_cast<int>(x1[i].x + 0.5);

    ll ans = 0;
    for(int i = 1; i <= m; i ++){
        // Both sides of a + c = 2b carry the shift twice, so the shifted index is 2 * b[i].
        const int target = 2 * b[i];
        // Exponents outside the computed range have coefficient zero.
        if(target >= 0 && target < len)ans += sum[target];
    }
    printf("%lld\n", ans);
    return 0;
}

浙公网安备 33010602011771号