poj 3233 Matrix Power Series (矩阵快速幂 + 二分)

http://poj.org/problem?id=3233

题意:  

题意:已知一个n*n的矩阵A,和一个正整数k,求S A A2 A3 + … + Ak

 

/*第一次写时,写挫啦,tle 一次,后来,稍微改动了一下,ac

矩阵快速幂。首先我们知道 A^x 可以用矩阵快速幂求出来。
其次可以对k进行二分,
每次将规模减半,分k为奇偶两种情况,如当k = 6和k = 7时有:
 S(6) = (1 + A^3) * (A + A^2 + A^3) = (1 + A^3) * S(3)。
 s(7)  = (1 + A^3) * (A + A^2 + A^3) + A^7 = (1 + A^3)*(s(3)) + A^7;
*/

#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<set>
#include<map>
#include<queue>
#include<vector>
#include<string>
#define Min(a,b) a<b?a:b
#define Max(a,b) a>b?a:b
#define CL(a,num) memset(a,num,sizeof(a));
#define maxn  40
#define eps  1e-6
#define inf 9999999
#define mx 1<<60

using namespace std;
struct   martrix
{
    int m[31][31];

};
int n,mod,k;
 martrix  mtadd (martrix a,martrix b)
   {
       int i ,j;
       martrix c;
       for(i = 0;i < n;i++)
       {

           for(j = 0 ; j< n;j++)
           {
               c.m[i][j] = 0;
               c.m[i][j] = (a.m[i][j] + b.m[i][j])%mod;

           }
       }
       return c;
   }

martrix mtmul(martrix a,martrix b)
{
    martrix c;
    int i,j,k;
    for(i = 0; i < n; i++)
    {
        for(j = 0; j < n;j++)
        {
            c.m[i][j] = 0;
            for(k = 0 ; k < n;k++)
            {
                c.m[i][j] += a.m[i][k] * b.m[k][j];
                c.m[i][j] %=mod;
            }
        }
    }


    return c;
}
martrix mtpow(martrix d,int k)
{   martrix a;
    if(k == 1return d ;
    int mid = k / 2;
    a = mtpow(d,k/2);
    a = mtmul(a,a);
    if(k & 1)
    {
        a = mtmul(a,d);
    }
    return a;


}
martrix solve(martrix a,int k)
{
    martrix b,c,d;
    if(k == 1return a ;
    int mid = k / 2;
    b = mtpow(a,mid);
    d = solve(a,mid);

    c = mtmul(b,d) ;
    c = mtadd(c,d);

    if(k&1)
    {
        c = mtadd(mtpow(a,k),c);
    }
    return  c;
}
int main()
{

    martrix a,b;

     int i,j;
    while(scanf("%d%d%d",&n,&k,&mod)!=EOF)
    {
        for(i = 0; i < n;i++)
        {
            for(j = 0; j < n;j++)
             scanf("%d",&a.m[i][j]);
        }
        b = solve(a,k);
        for(i =0 ; i < n;i++)
        {
            for(j = 0; j <n ;j++)
            {
                if(j == 0)printf("%d",b.m[i][j] % mod);
                else printf(" %d",b.m[i][j] % mod);
            }
            printf("\n");
        }
    }
}
posted @ 2012-08-20 09:57  Szz  阅读(196)  评论(0编辑  收藏  举报