Strassen算法
Strassen
快速矩阵乘法,时间复杂度可以做到 \(\Theta(n^{\log_27})\)
考虑把一个矩阵化为 \(2^k\times2^k\) 的形式,然后分为四个角上大小均等的子矩阵
如果有 \(C=A*B\)
那么不妨记
\[A=\left(
\begin{array}{ll}
A_{1,1} & A_{1,2} \\
A_{2,1} & A_{2,2}
\end{array}\right)\\
B=\left(
\begin{array}{ll}
B_{1,1} & B_{1,2} \\
B_{2,1} & B_{2,2}
\end{array}\right)\\
C=\left(
\begin{array}{ll}
C_{1,1} & C_{1,2} \\
C_{2,1} & C_{2,2}
\end{array}\right)\\
\]
设
\[S_1=B_{1,2}-B_{2,2}\\
S_2=A_{1,1}+A_{1,2}\\
S_3=A_{2,1}+A_{2,2}\\
S_4=B_{2,1}-B_{1,1}\\
S_5=A_{1,1}+A_{2,2}\\
S_6=B_{1,1}+B_{2,2}\\
S_7=A_{1,2}-A_{2,2}\\
S_8=B_{2,1}+B_{2,2}\\
S_9=A_{1,1}-A_{2,1}\\
S_{10}=B_{1,1}+B_{1,2}
\]
再设
\[P_1=A_{1,1}S_1\\
P_2=S_2B_{2,2}\\
P_3=S_3B_{1,1}\\
P_4=A_{2,2}S_4\\
P_5=S_5S_6\\
P_6=S_7S_8\\
P_7=S_9S_{10}
\]
代换得
\[\begin{align*}
C_{1,1}&=P_5+P_4-P_2+P_6\\
C_{1,2}&=P_1+P_2\\
C_{2,1}&=P_3+P_4\\
C_{2,2}&=P_5+P_1-P_3-P_7
\end{align*}
\]
做加减是 \(\Theta(n^2)\) 的,于是
\[T(n)=\Theta(n^2)+7T(\frac n2)
\]
由于 \(\log_27>2\),由主定理即可得
\[T(n)=\Theta(n^{\log_27})
\]
给出一个常数较小且比较简单的实现
// Modulus used by all arithmetic helpers below.
#define mod 1000000007
// Returns x+y reduced by at most one subtraction of mod.
// The result lies in [0, mod) when both inputs lie in [0, mod]
// (operator- passes mod-b, which equals mod when b is 0).
inline int add(int x,int y)
{
    int sum=x+y;
    if(sum>=mod) sum-=mod;
    return sum;
}
// Multiplies x by y in ll arithmetic and returns the product,
// taking the remainder modulo mod only when the product is not already below mod.
inline int ast(ll x,int y)
{
    const ll product=x*y;
    if(product<mod) return product;
    return product%mod;
}
// Replace any previously defined assert with a statement-style variant:
//   assert(cond) { body }
// expands to
//   for (; !(cond); __builtin_exit(114514)) { body }
// so when cond is false the trailing body runs once and then
// __builtin_exit(114514) is evaluated. __builtin_exit is not defined in this
// file; presumably a compiler builtin that terminates the process -- TODO confirm
// on the target compiler.
#undef assert
#define assert(_expr) for (; !(_expr); \
__builtin_exit(114514) )
// Passed as the `unit` argument of martix's constructor by operator* to get
// zero-filled storage without the identity diagonal.
#define CLEAR_MARTIX 2
// Pointer to an array of row pointers (2-D int array).
typedef int** Int;
class martix
{
public:
int x;
Int a;
void clr()
{
for(int i=0;i<x;++i) delete []a[i];
delete []a;
}
void print()
{
for(int i=0;i<x;++i,putchar('\n'))
for(int j=0;j<x;++j)
write_(a[i][j],' ');
}
martix()=default;
martix(int len,int unit=0)
{
x=len;a=new int* [x];
for(int i=0;i<x;++i) a[i]=new int [x];
if(unit)
{
for(int i=0;i<x;++i)
for(int j=0;j<x;++j)
a[i][j]=0;
if(unit==1) for(int i=0;i<x;++i) a[i][i]=1;
}
}
friend martix operator + (const martix& a,const martix& b)
{
assert(a.x==b.x) {debug_out("Not equal!");}
martix c(a.x);
int len=a.x;
for(int i=0;i<len;++i)
for(int j=0;j<len;++j)
c.a[i][j]=add(a.a[i][j],b.a[i][j]);
return c;
}
friend martix operator - (const martix& a,const martix& b)
{
assert(a.x==b.x) {debug_out("Not equal!");}
martix c(a.x);
int len=a.x;
for(int i=0;i<len;++i)
for(int j=0;j<len;++j)
c.a[i][j]=add(a.a[i][j],mod-b.a[i][j]);
return c;
}
friend martix operator * (const martix& a,const martix& b)
{
assert(a.x==b.x) {debug_out("Not equal!");}
martix c(a.x,CLEAR_MARTIX);
int len=a.x;
for(int i=0;i<len;++i)
for(int k=0;k<len;++k)
{
int s=a.a[i][k];
for(int j=0;j<len;++j)
c.a[i][j]=add(c.a[i][j], ast(s,b.a[k][j]) );
}
return c;
}//普通乘法
};
// Cuts `a` into its four equally sized quadrants:
//   a11 = top-left,    a12 = top-right,
//   a21 = bottom-left, a22 = bottom-right.
// Each output is replaced by a freshly constructed (a.x/2)-sized matrix.
// If a.x is odd the function returns without touching the outputs.
inline void split(martix &a,martix &a11,martix &a12,martix &a21,martix &a22)
{
    if(a.x%2!=0) return;
    const int half=a.x/2;
    a11=martix(half);
    a12=martix(half);
    a21=martix(half);
    a22=martix(half);
    for(int r=0;r<half;++r)
    {
        for(int c=0;c<half;++c)
        {
            a11.a[r][c]=a.a[r][c];
            a12.a[r][c]=a.a[r][c+half];
            a21.a[r][c]=a.a[r+half][c];
            a22.a[r][c]=a.a[r+half][c+half];
        }
    }
}
// Assembles four equally sized quadrants into `a`:
//   a11 = top-left,    a12 = top-right,
//   a21 = bottom-left, a22 = bottom-right.
// `a` is replaced by a freshly constructed matrix of twice the quadrant side.
// All four quadrants must have the same side length; the check covers a21 and
// a22 as well, because a smaller one would otherwise be read out of bounds.
inline void merge(martix &a,martix &a11,martix &a12,martix &a21,martix &a22)
{
    assert(a11.x==a12.x&&a11.x==a21.x&&a11.x==a22.x){debug_out("Not equal!");}
    int len=a11.x;
    a=martix(len<<1);
    for(int i=0;i<len;++i)
        for(int j=0;j<len;++j)
            a.a[i][j]=a11.a[i][j],
            a.a[i][j+len]=a12.a[i][j],
            a.a[i+len][j]=a21.a[i][j],
            a.a[i+len][j+len]=a22.a[i][j];
}
// Multiplies two square matrices of equal side length modulo mod using
// Strassen's seven-product recursion.
//
// Falls back to the naive operator* when the side length is odd (the matrix
// cannot be halved) or at most `cutoff`. The cutoff
//   * stops the recursion for side length 0, which is even and would otherwise
//     recurse forever on empty quadrants, and
//   * avoids recursing down to 1x1 blocks, where allocation and the extra
//     additions cost far more than the multiplications they save.
// Returns the product as a new matrix.
inline martix Strassen(martix& a,martix& b)
{
    assert(a.x==b.x) {debug_out("Not equal!");}
    const int cutoff=64;  // tunable: sides up to this use the naive product
    if((a.x&1)||a.x<=cutoff) return a*b;

    martix a11,a12,a21,a22;
    martix b11,b12,b21,b22;
    split(a,a11,a12,a21,a22);
    split(b,b11,b12,b21,b22);

    // The ten quadrant sums/differences S1..S10.
    martix s1 =b12 - b22;
    martix s2 =a11 + a12;
    martix s3 =a21 + a22;
    martix s4 =b21 - b11;
    martix s5 =a11 + a22;
    martix s6 =b11 + b22;
    martix s7 =a12 - a22;
    martix s8 =b21 + b22;
    martix s9 =a11 - a21;
    martix s10=b11 + b12;

    // The seven recursive half-size products P1..P7.
    martix p1=Strassen(a11,s1);
    martix p2=Strassen(s2,b22);
    martix p3=Strassen(s3,b11);
    martix p4=Strassen(a22,s4);
    martix p5=Strassen(s5,s6 );
    martix p6=Strassen(s7,s8 );
    martix p7=Strassen(s9,s10);

    // Recombine into the four quadrants of the result.
    martix c11=p5+p4-p2+p6;
    martix c12=p1+p2;
    martix c21=p3+p4;
    martix c22=p5+p1-p3-p7;

    martix c;
    merge(c,c11,c12,c21,c22);
    return c;
}
完全就是小品呢~

浙公网安备 33010602011771号