BZOJ3992[SDOI2015]序列统计

题目链接

洛谷

BZOJ

解析

头一回知道原根还可以这么考……

不难想到递推的做法\(dp[i][j]\)表示长度为\(i\),乘积为\(j\)的答案,那么\(dp[i][j \cdot a[i] \ mod \ M] += dp[i - 1][j]\)

首先我们发现\(0\)可以直接丢掉,因为包含\(0\)的序列对答案不产生任何贡献,然后题目说\(M\)是个质数,如果设\(G\)\(M\)的原根,那么\([1, M - 1]\)\(G^1 \ mod \ M, G^2 \ mod \ M, ..., G^{M - 1} \ mod \ M\)一一对应,那么可以用\(G\)头上的指数来代替\([1, M - 1]\)中的一个数,设\(dp[i][j]\)表示长度为\(i\),乘积为\(G^j \ mod \ M\)的答案,则\(dp[i][(j + k)] += dp[i - 1][j]\)

\(f[j]\)表示\(dp[i][j]\)\(g[j]\)表示\(dp[i - 1][j]\)\(h[j]\)表示给定的集合中是否存在\(G^j \ mod \ M\)这个数,那么

\[f[j] = \sum_{k = 1}^{M - 1} g[k] \cdot h[j - k] \]

这是一个卷积

发现初始只有\(f[0] = 1\),所以只需对\(h\)快速幂卷积即可

当然上面都是不考虑指数对\(M - 1\)取模的时候,考虑取模就每次卷积后把\(f[i](i > M - 1)\)加到\(f[i \ mod \ (M - 1)]\)的位置就可以了

代码

不知道为什么好像我的\(FFT\)\(NTT\)都自带大常数……耗时比第一页最慢的都多了一倍加\(1s\)……

#include <cstring>
#include <iostream>
#include <cstdio>
#include <vector>
#define MAXN 8010
 
typedef long long LL;
const LL mod = 1004535809ll;
 
int qpower(int, int, int);
void pre_prime();
void divide(int, std::vector<int> &);
int get_g(int);
void pre_rev(int);
void NTT(int *, int, int);
void mul(int *, int *, int);
 
int N, M, X, SZ, n, g, pre[MAXN], f[MAXN << 2], ans[MAXN << 2], rev[MAXN << 2];
std::vector<int> prime, dvd;
 
inline void inc(int &x, int y) { x += y; if (x >= mod) x -= mod; }
inline void dec(int &x, int y) { x -= y; if (x < 0) x += mod; }
inline int add(int x, int y) { x += y; return x >= mod ? x - mod : x; }
inline int sub(int x, int y) { x -= y; return x < 0 ? x + mod : x; }
 
int main() {
    pre_prime();
    scanf("%d%d%d%d", &N, &M, &X, &SZ);
    g = get_g(M);
    for (int i = 1, j = g; i < M; ++i, j = (LL)j * g % M) pre[j] = i;
    for (int i = 0; i < SZ; ++i) {
        int t; scanf("%d", &t);
        if(t) ++f[pre[t]];
    }
    //debug
    //printf("%d\n", g);
    //for (int i = 0; i < M; ++i) printf("%d ", pre[i]);
    //puts("");
    ans[0] = 1;
    while ((1 << n) < (M << 1)) ++n;
    pre_rev(n);
    while (N) {
        if (N & 1) mul(ans, f, n);
        mul(f, f, n);
        N >>= 1;
    }
    printf("%d\n", ans[pre[X]]);
 
    return 0;
}
void pre_prime() {
    static bool isn_prime[MAXN];
    for (int i = 2; i < MAXN; ++i) {
        if (!isn_prime[i]) prime.push_back(i);
        for (int j = 0; j < prime.size(), i * prime[j] < MAXN; ++j) {
            isn_prime[i * prime[j]] = 0;
            if (i % prime[j] == 0) break;
        }
    }
}
int get_g(int x) {
    divide(x - 1, dvd);
    for (int i = 2; i < x; ++i)
        if (qpower(i, x - 1, x) == 1) {
            bool flag = 1;
            for (int j = 0; j < dvd.size(); ++j)
                if (qpower(i, (x - 1) / dvd[j], x) == 1) { flag = 0; break; }
            if (flag) return i;
        }
}
void divide(int x, std::vector<int> &res) {
    for (int i = 0; i < prime.size(); ++i)
        if (x % prime[i] == 0) {
            res.push_back(prime[i]);
            while (x % prime[i] == 0) x /= prime[i];
        }
}
int qpower(int x, int y, int p) {
    int res = 1;
    while (y) {
        if (y & 1) res = (LL)res * x % p;
        x = (LL)x * x % p; y >>= 1;
    }
    return res;
}
void NTT(int *arr, int sz, int tp) {
    for (int i = 0; i < (1 << sz); ++i)
        if (rev[i] > i) std::swap(arr[i], arr[rev[i]]);
    for (int len = 2, half = 1; len <= (1 << sz); len <<= 1, half <<= 1) {
        int wn = qpower(3, (mod - 1) / len, mod);
        if (tp == -1) wn = qpower(wn, mod - 2, mod);
        for (int i = 0; i < (1 << sz); i += len)
            for (int j = 0, w = 1; j < half; ++j, w = (LL)w * wn % mod) {
                int x = arr[i + j], y = (LL)arr[i + j + half] * w % mod;
                inc(arr[i + j], y); dec(arr[i + j + half] = x, y);
            }
    }
    if (tp == -1) {
        int inv = qpower(1 << sz, mod - 2, mod);
        for (int i = 0; i < (1 << sz); ++i) arr[i] = (LL)arr[i] * inv % mod;
    }
}
void mul(int *a, int *b, int sz) {
    static int tmp[MAXN << 2];
    for (int i = 0; i < (1 << sz); ++i) tmp[i] = b[i];
    NTT(a, sz, 1); NTT(tmp, sz, 1);
    for (int i = 0; i < (1 << sz); ++i) a[i] = (LL)a[i] * tmp[i] % mod;
    NTT(a, sz, -1);
    for (int i = 1; i < M; ++i) inc(a[i], a[i + M - 1]), a[i + M - 1] = 0; 
}
void pre_rev(int sz) {
    for (int i = 0; i < (1 << sz); ++i)
        rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << sz - 1));
}
//Rhein_E
posted @ 2019-03-14 10:02  Rhein_E  阅读(134)  评论(0编辑  收藏  举报