Codeforces 1096G. Lucky Tickets【生成函数】

LINK

题目大意

很简单自己看

思路

考虑生成函数(为啥tags里面有一个dp啊)

显然,每一个指数上是否有系数是由数集中是否有这个数决定的

有的话就是1没有就是0

然后求出这个生成函数的\(\frac{n}{2}\)次方

把每一项的系数全部平方加起来。。没了


#include<bits/stdc++.h>

using namespace std;

typedef vector<int> Poly;

const int N = 3e6 + 10;
const int Mod = 998244353;
const int G = 3;

int add(int a, int b, int mod = Mod) {
  return (a += b) >= mod ? a - mod : a;
}

int sub(int a, int b, int mod = Mod) {
  return (a -= b) < 0 ? a + mod : a;
}

int mul(int a, int b, int mod = Mod) {
  return 1ll * a * b % mod;
}

int fast_pow(int a, int b, int mod = Mod) {
  int res = 1;
  for (; b; b >>= 1, a = mul(a, a, mod))
    if (b & 1) res = mul(res, a, mod);
  return res;
}

int w[N][2];

void init() {
  for (int i = 1; i < (1 << 21); i <<= 1) {
    w[i][0] = w[i][1] = 1;
    int wn = fast_pow(G, (Mod - 1) / (i << 1));
    for (int j = 1; j < i; j++)
      w[i + j][0] = mul(w[i + j - 1][0], wn);
    wn = fast_pow(G, Mod - 1 - (Mod - 1) / (i << 1));
    for (int j = 1; j < i; j++)
      w[i + j][1] = mul(w[i + j - 1][1], wn);
  }
}

void transform(int *t, int len, int typ) {
  for (int i = 0, j = 0, k; j < len; j++) {
    if (i > j) swap(t[i], t[j]);
    for (k = (len >> 1); k & i; k >>= 1) i ^= k;
    i ^= k; 
  }
  for (int i = 1; i < len; i <<= 1) {
    for (int j = 0; j < len; j += i << 1) {
      for (int k = 0; k < i; k++) {
        int x = t[j + k], y = mul(t[j + k + i], w[i + k][typ]);
        t[j + k] = add(x, y);
        t[j + k + i] = sub(x, y); 
      }
    } 
  }
  if (typ) return;
  int invlen = fast_pow(len, Mod - 2);
  for (int i = 0; i < len; i++)
    t[i] = mul(t[i], invlen);
}

Poly fast_pow(Poly a, int b) {
  int len = 1 << (int) ceil(log2(a.size()));
  a.resize(len);
  transform(&a[0], len, 1);
  for (int i = 0; i < len; i++)
    a[i] = fast_pow(a[i], b);
  transform(&a[0], len, 0);
  return a;
}

int n, k;

int main() {
  init();
  scanf("%d %d", &n, &k);
  Poly a((int) 2e6);
  for (int i = 1; i <= k; i++) {
    int x; 
    scanf("%d", &x);
    a[x] = 1; 
  }
  a = fast_pow(a, n / 2);
  int ans = 0;
  for (int i = 0; i < (signed) a.size(); i++)
    ans = add(ans, mul(a[i], a[i]));
  printf("%d", ans);
  return 0;
}
posted @ 2019-01-06 23:29 Dream_maker_yk 阅读(...) 评论(...) 编辑 收藏