NYOJ 301 递推求值

第一次写博客,拿个矩阵快速幂练练手吧。

首先什么是快速幂,快速幂是让复杂度由线性降为log n的算法,比如8^1024次方暴力要算1024次,但是矩阵快速幂只算10次就好。

此题只不过是把快速幂的底数变为一个矩阵,矩阵乘法手写,然后计算矩阵的n次方的时候使用快速幂。

此题和矩阵n次方有什么关系?

我们先来构造如下矩阵A:

f2  0   0
f1  0   0
1   0   0
和另一个为矩阵B:
b   a   c
1   0   0
0   0   1

矩阵A每乘一次矩阵B,新矩阵第一个值便是递推公式的下一个值。

n的值为10^9,如果乘10^9个矩阵B必然会超时,所以使用快速幂,在中取模就好了(过程取模对乘法来说对结果无影响)。

理论上快速幂使用的时候最大的n能到2^(10^8)这么多。

代码如下:

 1 #include<iostream>
 2 #include<stdio.h>  
 3 #include<string.h>  
 4 using namespace std;
 5 #define mod 1000007  
 6 #define ll long long
 7 
 8 struct matri  
 9 {  
10     ll mat[3][3];  
11 } one= {1,0,0,0,1,0,0,0,1};  
12 
13 matri mul(matri a, matri b) 
14 {  
15     matri res;  
16     for(int i=0;i<3;i++)  
17         for(int j=0;j<3;j++)  
18         {  
19             res.mat[i][j] = 0;  
20             for(int k=0;k<3;k++)  
21             {  
22                 res.mat[i][j] += a.mat[i][k] * b.mat[k][j];  
23                 res.mat[i][j] %= mod;  
24             }  
25         }  
26     return res;  
27 }  
28   
29 matri pow(matri a, ll n)
30 {  
31     matri res = one;  
32     while(n != 0)  
33     {  
34         if(n & 1)  
35             res = mul(res, a);  
36         a = mul(a, a);  
37         n >>= 1;  
38     }  
39     return res;  
40 }  
41   
42 int main()  
43 {  
44     ll n,f1,f2,a,b,c,T;  
45     matri tmp,arr;  
46     scanf("%lld",&T);  
47     while(T--)  
48     {  
49         scanf("%lld%lld%lld%lld%lld%lld",&f1,&f2,&a,&b,&c,&n);  
50         if(n==0)  
51             printf("%lld\n",(f2-f1*b-c+mod)%mod);  
52         if(n==1)  
53             printf("%lld\n",(f1+mod)%mod);  
54         else if(n==2)  
55             printf("%lld\n",(f2+mod)%mod);  
56         else  
57         {  
58             memset(arr.mat, 0, sizeof(arr.mat));  
59             arr.mat[0][0] = f2;arr.mat[1][0] = f1; arr.mat[2][0] = 1;  
60             tmp.mat[0][0] = b; tmp.mat[0][1] = a;  tmp.mat[0][2] = c;  
61             tmp.mat[1][0] = tmp.mat[2][2] = 1;  
62             tmp.mat[1][1] = tmp.mat[1][2] = tmp.mat[2][0] = tmp.mat[2][1] = 0;  
63             matri p = pow(tmp, n-2);  
64             p = mul(p, arr);  
65             ll ans = (p.mat[0][0] + mod) % mod;  
66             printf("%lld\n",ans);  
67         }  
68     }  
69     return 0;  
70 }  

 

posted @ 2018-04-24 01:25  fantastic123  阅读(198)  评论(1)    收藏  举报