bzoj3992

fft+数论

抄了很多地方

首先我们有个暴力dp,dp[i][j]表示到第i个数,乘积%m=j的方案数

那么可以暴力转移。

这个东西很明显不能矩阵快速幂,又不是卷积,但是我们可以转化

还记得原根吗

g^0...g^P-2 %P的余数各不相同,我们可以用指标代替数

指标就是上面的幂

设a的指标为ind[a]

因为j=1->m-1,那么ind∈[0,m-1)可以表示[1,m)的所有数

ind[a*b]=(ind[a]+ind[b])%(m-1)

这样就可以构造卷积转移了,但是每次fft完要把i+m-1的东西加到i上,然后清零

复杂度nmlogm

卷积满足交换律和结合律

那么我们用多项式快速幂就行了

初始的ans是ans[0]=1.因为ind[1]=0

复杂度mlogmlogn

#include<bits/stdc++.h>
using namespace std;
const int N = 32005, P = 1004535809;
typedef long long ll;
int t, m, x, s, n, k, g;
ll a[N], b[N], ans[N], ind[N];
ll power(ll x, ll t, ll P)
{
    ll ret = 1;
    for(; t; t >>= 1, x = x * x % P) if(t & 1) ret = ret * x % P;
    return ret;
}
void ntt(ll *a, int f)
{
    for(int i = 0; i < n; ++i)
    {
        int t = 0;
        for(int j = 0; j < k; ++j) if(i >> j & 1) t |= 1 << (k - j - 1);
        if(i < t) swap(a[i], a[t]);
    }
    for(int l = 2; l <= n; l <<= 1)
    {
        ll w = power(3, f == 1 ? (P - 1) / l : P - 1 - (P - 1) / l, P);
        int m = l >> 1;
        for(int i = 0; i < n; i += l) 
        {
            ll t = 1;
            for(int k = 0; k < m; ++k, t = t * w % P)
            {
                ll x = a[i + k], y = t * a[i + m + k];
                a[i + k] = (x + y) % P;
                a[i + m + k] = ((x - y) % P + P) % P;
            }
        }        
    }
    if(f == -1)
    {
        ll inv = power(n, P - 2, P);
        for(int i = 0; i < n; ++i) a[i] = a[i] * inv % P;
    }
}
void sqr(ll *a)
{
    ntt(a, 1);
    for(int i = 0; i < n; ++i) a[i] = a[i] * a[i] % P;
    ntt(a, - 1);
    for(int i = 0; i < m - 1; ++i) a[i] = (a[i] + a[i + m - 1]) % P, a[i + m - 1] = 0;
}
ll get_(int m)
{
    if(m == 2) return 1;
    for(ll g = 2; g < m; ++g)
    {  
        bool flag =  1;
        ll lim = sqrt(m);
        for(ll i = 2; i <= lim; ++i) if((m - 1) % i == 0)
        {
            if(power(g, (m - 1) / i, m) == 1)
            {
                flag = false;
                break;
            }
        }
        if(flag) return g;
    }
}
int main()
{
    scanf("%d%d%d%d", &t, &m, &x, &s);
    n = 1;
    k = 0;
    while(n <= 2 * m) n <<= 1, ++k;
    g = get_(m);
    ll A = 1;
    for(int i = 0; i < m - 1; ++i, A = A * g % m) ind[A] = i; 
    ans[0] = 1; 
    for(int i = 1; i <= s; ++i) 
    {
        int x;
        scanf("%d", &x);
        if(x) a[ind[x]] = 1;
    }
    for(; t; t >>= 1, sqr(a)) if(t & 1) 
    {
        for(int i = 0; i < n; ++i) b[i] = a[i];
        ntt(b, 1);
        ntt(ans, 1);
        for(int i = 0; i < n; ++i) ans[i] = ans[i] * b[i] % P;
        ntt(ans, -1);
        for(int i = 0; i < m - 1; ++i) ans[i] = (ans[i] + ans[i + m - 1]) % P, ans[i + m - 1] = 0;
    }
    printf("%lld\n", ans[ind[x]]);
    return 0;
}
View Code

 

posted @ 2017-12-18 17:23  19992147  阅读(164)  评论(0编辑  收藏  举报