ACM学习历程—HDU5667 Sequence(数论 && 矩阵乘法 && 快速幂)

http://acm.hdu.edu.cn/showproblem.php?pid=5667

这题的关键是处理指数,因为最后结果是a^t这种的,主要是如何计算t。

发现t是一个递推式,t(n) = c*t(n-1)+t(n-2)+b。这样的话就可以使用矩阵快速幂进行计算了。

设列矩阵[t(n), t(n-1), 1],它可以由[t(n-1), t(n-2), 1]乘上一个3*3的矩阵得到这个矩阵为:{[c, 1, b], [1, 0, 0], [0, 0, 1]},这样指数部分就可以矩阵快速幂了。

但是如果指数不模的话,计算肯定爆了,这里需要考虑费马小定理,a^(p-1) = 1(mod p),于是指数就可以模(p-1)了。

最后算出指数后,再来一次快速幂即可。

但是打这场BC的时候,我并没有考虑到a%p = 0的情况。。。最终错失这题,只过了三题。

 

代码:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <set>
#include <map>
#include <queue>
#include <vector>
#include <string>
#define LL long long

using namespace std;

//矩阵乘法
//方阵
#define maxN 4
struct Mat
{
    LL val[maxN][maxN], p;
    int len;

    Mat()
    {
        len = 3;
    }

    Mat operator=(const Mat& a)
    {
        len = a.len;
        p = a.p;
        for (int i = 0; i < len; ++i)
            for (int j = 0; j < len; ++j)
                val[i][j] = a.val[i][j];
        return *this;
    }

    Mat operator*(const Mat& a)
    {
        Mat x;
        x.p = a.p;
        memset(x.val, 0, sizeof(x.val));
        for (int i = 0; i < len; ++i)
            for (int j = 0; j < len; ++j)
                for (int k = 0; k < len; ++k)
                    if (val[i][k] && a.val[k][j])
                        x.val[i][j] = (x.val[i][j] + val[i][k]*a.val[k][j]%p)%p;
        return x;
    }

    Mat operator^(const LL& a)
    {
        LL n = a;
        Mat x, p = *this;
        memset(x.val, 0, sizeof(x.val));
        x.p = this->p;
        for (int i = 0; i < len; ++i)
            x.val[i][i] = 1;
        while (n)
        {
            if (n & 1)
                x = x * p;
            p = p * p;
            n >>= 1;
        }
        return x;
    }
}from, mat;

LL n, a, b, c, p;

//快速幂m^n
LL quickPow(LL x, LL n)
{
    LL a = 1;
    while (n)
    {
        a *= n&1 ? x : 1;
        a %= p;
        n >>= 1 ;
        x *= x;
        x %= p;
    }
    return a;
}

void work()
{
    if (a%p == 0)
    {
        if (n == 1) printf("1\n");
        else printf("0\n");
        return;
    }
    LL t, ans;
    if (n == 1)
        t = 0;
    else if (n == 2)
        t = b%(p-1);
    else
    {
        memset(from.val, 0, sizeof(from.val));
        from.val[0][0] = c;
        from.val[0][1] = 1;
        from.val[0][2] = b;
        from.val[1][0] = 1;
        from.val[2][2] = 1;
        from.len = 3;
        from.p = p-1;
        mat = from^(n-2);
        t = (mat.val[0][0]*b%(p-1)+mat.val[0][2])%(p-1);
    }
    ans = quickPow(a, t);
    cout << ans << endl;
}

int main()
{
    //freopen("test.in", "r", stdin);
    int T;
    scanf("%d", &T);
    for (int times = 1; times <= T; ++times)
    {
        cin >> n >> a >> b >> c >> p;
        work();
    }
    return 0;
}
View Code

 

posted on 2016-04-24 19:28  AndyQsmart  阅读(485)  评论(0编辑  收藏  举报

导航