自己写的矩阵类Matrix

今天老板出差了,闲来无事,自己写了一个矩阵类当作消遣,希望最终能实现像 MATLAB 一样强大的功能。今天只实现了最基本的部分,以后还需要多多改进。

头文件:

 1 #ifndef _MATRIX_H_
 2 #define _MATRIX_H_
 3 
 4 template <class DataType>
 5 class Matrix
 6 {
 7 public:
 8     Matrix();                                                            //constructor
 9     Matrix(int r, int c);                                                //constructor the matrix of r*c
10     ~Matrix();                                                            //deconstructor
11     void SetMem(int r, int c, DataType v);                                //set the element matrix[r][c]
12     DataType GetMem(int r, int c) const;                                //get the element matrix[r][c]
13     void SetRow(int r, DataType *a, int len);                            //set the value of the r row
14     void SetCol(int c, DataType *a, int len);                            //set the value of the c col
15     void Zeros();                                                        //set all element 0
16     void Ones();                                                        //set all element 1
17 
18     Matrix<DataType>* GetRow(int r);                                    //get the r row
19     Matrix<DataType>* GetCol(int c);                                    //get the c column
20 
21     int GetRows();                                                        //get the number of rows
22     int GetColumns();                                                    //get the number of columns
23     
24     Matrix<DataType>& operator + (Matrix<DataType> &rhl);                //matrix + matrix
25     Matrix<DataType>& operator + (DataType v);                            //matrix + DataType
26     Matrix<DataType>& operator - (Matrix<DataType> &rhl);                //matrix - matrix
27     Matrix<DataType>& operator - (DataType v);                            //matrix - DataType
28     Matrix<DataType>& operator * (Matrix<DataType> &rhl);                //matrix .* matrix
29     Matrix<DataType>& operator * (DataType v);                            //matrix * DataType
30     Matrix<DataType>& operator / (Matrix<DataType> &rhl);                //matrix ./ matrix
31     Matrix<DataType>& operator / (DataType v);                            //matrix / DataType
32     Matrix<DataType>& operator = (Matrix<DataType> &rhl);                //matrix = matrix
33     Matrix<DataType>& operator = (DataType v);                            //matrix = DataType
34     bool operator == (Matrix<DataType> &rhl);                            //return true if the every element is equal,else false
35     DataType operator () (int x, int y);                                //the sub index
36     Matrix<DataType>* operator () (int x, char flag);                    //the sub index,flag=='R',return x row;else 'C',return x column
37     Matrix<DataType>& MatrixMul(Matrix<DataType>& lhs,Matrix<DataType>& rhs);
38     
39     void Show();                                                        //display the matrix
40 private:
41     DataType* data;                                                        //pointer to data
42     int rows;                                                            //row
43     int cols;                                                            //column
44     inline void DeepCopy(Matrix<DataType>& org);                        //copy every element of the matrix
45 };
46 
47 
48 #endif

实现源文件:

  1 #include "Matrix.h"
  2 #include <iostream>
  3 using namespace std;
  4 
  5 template <class DataType>
  6 Matrix<DataType>::Matrix()
  7 {
  8     rows=1;
  9     cols=1;
 10     data=new DataType[rows*cols];
 11     memset(data,0,rows*cols*sizeof(DataType));
 12 }
 13 
 14 template <class DataType>
 15 Matrix<DataType>::Matrix(int r, int c)
 16 {
 17     rows=r;
 18     cols=c;
 19     data=new DataType[rows*cols];
 20     memset(data,0,rows*cols*sizeof(DataType));
 21 }
 22 
 23 template <class DataType>
 24 Matrix<DataType>::~Matrix()
 25 {
 26     delete[] data;
 27 }
 28 
 29 template <class DataType>
 30 void Matrix<DataType>::Show()
 31 {
 32     for (int i=0;i<rows;i++)
 33     {
 34         for (int j=0;j<cols;j++)
 35         {
 36             cout<<data[i*cols+j]<<" ";
 37             if(j==cols-1)
 38                 cout<<endl;
 39         }
 40     }
 41 }
 42 
 43 template <class DataType>
 44 void Matrix<DataType>::SetMem(int r, int c, DataType v)
 45 {
 46     if(r>rows||c>cols)
 47         return;
 48     data[(r-1)*cols+(c-1)]=v;
 49 }
 50 template <class DataType>
 51 DataType Matrix<DataType>::GetMem(int r, int c) const
 52 {
 53     if(r>rows||c>cols)
 54         return -1;
 55     return data[(r-1)*cols+(c-1)];
 56 }
 57 
 58 template <class DataType>
 59 void Matrix<DataType>::SetRow(int r, DataType *a, int len)
 60 {
 61     if (r>rows||len!=cols)
 62         return;
 63     
 64     for (int i=0;i<cols;i++)
 65     {
 66         data[(r-1)*cols+i]=a[i];
 67     }
 68 }
 69 
 70 template <class DataType>
 71 void Matrix<DataType>::SetCol(int c, DataType *a, int len)
 72 {
 73     if(c>cols||len!=rows)
 74         return;
 75     for (int i=0;i<rows;i++)
 76     {
 77         data[i*cols+(c-1)]=a[i];
 78     }
 79 }
 80 
 81 template <class DataType>
 82 Matrix<DataType>* Matrix<DataType>::GetRow(int r)
 83 {
 84     if (r>rows)
 85     {
 86         exit(0);
 87     }
 88     else
 89     {
 90         Matrix<DataType> *tmp=new Matrix<DataType>(1,this->cols);
 91         for (int i=1;i<=this->cols;i++)
 92         {
 93             DataType tm=0;
 94             tm=this->GetMem(r,i);
 95             tmp->SetMem(1,i,tm);
 96         }
 97         return tmp;
 98     }
 99 }
100 
101 template <class DataType>
102 Matrix<DataType>* Matrix<DataType>::GetCol(int c)
103 {
104     if (c>cols)
105     {
106         exit(0);
107     }
108     else
109     {
110         Matrix<DataType> *tmp=new Matrix<DataType>(1,this->rows);
111         DataType tm=0;
112         for (int i=1;i<=this->rows;i++)
113         {
114             tm=this->GetMem(i,c);
115             tmp->SetMem(1,i,tm);
116         }
117         return tmp;
118     }
119 }
120 
121 template <class DataType>
122 int Matrix<DataType>::GetColumns()
123 {
124     return cols;
125 }
126 
127 template <class DataType>
128 int Matrix<DataType>::GetRows()
129 {
130     return rows;
131 }
132 
133 template <class DataType>
134 Matrix<DataType>& Matrix<DataType>::operator + (Matrix<DataType>& rhl)
135 {
136     if (rows!=rhl.GetRows()||cols!=rhl.GetColumns())
137     {
138         exit(0);
139     }
140     else
141     {
142         for (int i=1;i<=rows;i++)
143         {
144             for (int j=1;j<=cols;j++)
145             {
146                 this->SetMem(i,j,this->GetMem(i,j)+rhl.GetMem(i,j));
147             }
148         }
149         return *this;
150     }
151 }
152 
153 template <class DataType>
154 Matrix<DataType>& Matrix<DataType>::operator - (Matrix<DataType>& rhl)
155 {
156     if (rows!=rhl.GetRows()||cols!=rhl.GetColumns())
157     {
158         exit(0);
159     }
160     else
161     {
162         for (int i=1;i<=rows;i++)
163         {
164             for (int j=1;j<=cols;j++)
165             {
166                 this->SetMem(i,j,this->GetMem(i,j)-rhl.GetMem(i,j));
167             }
168         }
169         return *this;
170     }
171 }
172 
173 template <class DataType>
174 Matrix<DataType>& Matrix<DataType>::operator * (Matrix<DataType>& rhs)
175 {
176 
177         if (rows!=rhs.GetRows()||cols!=rhs.GetColumns())
178         {
179             exit(0);
180         }
181         for (int i=1;i<=rows;i++)
182         {
183             for (int j=1;j<=cols;j++)
184             {
185                 this->SetMem(i,j,this->GetMem(i,j)*rhs.GetMem(i,j));
186             }
187         }
188         return *this;
189 }
190 
191 
192 template <class DataType>
193 Matrix<DataType>& Matrix<DataType>::operator / (Matrix<DataType>& rhl)
194 {
195     if (rows!=rhl.GetRows()||cols!=rhl.GetColumns())
196     {
197         exit(0);
198     }
199     else
200     {
201         for (int i=1;i<=rows;i++)
202         {
203             for (int j=1;j<=cols;j++)
204             {
205                 this->SetMem(i,j,this->GetMem(i,j)/rhl.GetMem(i,j));
206             }
207         }
208         return *this;
209     }
210 }
211 
212 template <class DataType>
213 void Matrix<DataType>::DeepCopy(Matrix<DataType>& org)
214 {
215     if(cols!=org.GetColumns()||rows!=org.GetRows())
216         return;
217     for (int i=1;i<=rows;i++)
218     {
219         for (int j=1;j<=cols;j++)
220         {
221             this->SetMem(i,j,org.GetMem(i,j));
222         }
223     }
224 }
225 
226 template <class DataType>
227 Matrix<DataType>& Matrix<DataType>::operator = (Matrix<DataType>& rhl)
228 {
229     if(this==&rhl)
230         return *this;
231     DeepCopy(rhl);
232     return *this;
233 }
234 
235 template <class DataType>
236 void Matrix<DataType>::Zeros()
237 {
238     for (int i=1;i<=rows;i++)
239     {
240         for (int j=1;j<=cols;j++)
241         {
242             this->SetMem(i,j,0.0);
243         }
244     }
245 }
246 
247 template <class DataType>
248 void Matrix<DataType>::Ones()
249 {
250     for (int i=1;i<=rows;i++)
251     {
252         for (int j=1;j<=cols;j++)
253         {
254             this->SetMem(i,j,1.0);
255         }
256     }
257 }
258 
259 template <class DataType>
260 Matrix<DataType>& Matrix<DataType>::operator + (DataType v)
261 {
262     for (int i=1;i<=rows;i++)
263     {
264         for (int j=1;j<=cols;j++)
265         {
266             this->SetMem(i,j,this->GetMem(i,j)+v);
267         }
268     }
269     return *this;
270 }
271 
272 template <class DataType>
273 Matrix<DataType>& Matrix<DataType>::operator - (DataType v)
274 {
275     for (int i=1;i<=rows;i++)
276     {
277         for (int j=1;j<=cols;j++)
278         {
279             this->SetMem(i,j,this->GetMem(i,j)-v);
280         }
281     }
282     return *this;
283 }
284 
285 template <class DataType>
286 Matrix<DataType>& Matrix<DataType>::operator * (DataType v)
287 {
288     for (int i=1;i<=rows;i++)
289     {
290         for (int j=1;j<=cols;j++)
291         {
292             this->SetMem(i,j,this->GetMem(i,j)*v);
293         }
294     }
295     return *this;
296 }
297 
298 template <class DataType>
299 Matrix<DataType>& Matrix<DataType>::operator / (DataType v)
300 {
301     for (int i=1;i<=rows;i++)
302     {
303         for (int j=1;j<=cols;j++)
304         {
305             this->SetMem(i,j,this->GetMem(i,j)/v);
306         }
307     }
308     return *this;
309 }
310 
311 template <class DataType>
312 Matrix<DataType>& Matrix<DataType>::operator = (DataType v)
313 {
314 
315     for (int i=1;i<=rows;i++)
316     {
317         for (int j=1;j<=cols;j++)
318         {
319             this->SetMem(i,j,v);
320         }
321     }
322     return *this;
323 }
324 
325 
326 template <class DataType>
327 bool Matrix<DataType>::operator == (Matrix<DataType> &rhl)
328 {
329     if (rows!=rhl.GetRows()||cols!=rhl.GetColumns())
330     {
331         return false;
332     }
333     else
334     {
335         int count=0;
336         for (int i=1;i<=rows;i++)
337         {
338             for (int j=1;j<=cols;j++)
339             {
340                 if(GetMem(i,j)==rhl.GetMem(i,j))
341                     count++;
342             }
343         }
344         if(count==rows*cols)
345             return true;
346         else
347             return false;
348     }
349 }
350 
351 template <class DataType>
352 DataType Matrix<DataType>::operator () (int x, int y)
353 {
354     return GetMem(x,y);
355 }
356 
357 template <class DataType>
358 Matrix<DataType>* Matrix<DataType>::operator () (int x, char flag)
359 {
360     if (flag=='R')
361     {
362         return GetRow(x);
363     }
364     else 
365     {
366         return GetCol(x);
367     }
368 }
369 //before call the function, must use operator new to apply some storage for the object
370 template <class DataType>
371 Matrix<DataType>& Matrix<DataType>::MatrixMul(Matrix<DataType>& lhs, Matrix<DataType>& rhs)
372 {
373     if (lhs.GetColumns()!=rhs.GetRows())
374     {
375         exit(0);
376     }
377     for (int i=1;i<=lhs.GetRows();i++)
378     {
379         for (int j=1;j<=rhs.GetColumns();j++)
380         {
381             DataType tm=0;
382             for (int k=1;k<=lhs.GetColumns();k++)
383             {
384                 tm+=GetMem(i,k)*rhs.GetMem(k,j);
385             }
386             this->SetMem(i,j,tm);
387         }
388     }
389     return *this;    
390 }

测试文件:

 1 #include <iostream>
 2 #include "Matrix.cpp"   //use the template, so must include the cpp file
 3 using namespace std;
 4 
 5 void main()
 6 {
 7     Matrix<double> dMat(2,4);
 8     dMat.SetMem(2,2,10.0);
 9     dMat.SetMem(2,4,5.0);
10     dMat.Show();
11 
12     Matrix<double> dMat1(4,2);
13     dMat1.SetMem(2,1,10.0);
14     dMat1.SetMem(2,2,3.0);
15     dMat1.Show();
16     
17     Matrix<double> *pMat=new Matrix<double>(4,4);
18     pMat->MatrixMul(dMat1,dMat);
19     pMat->Show();
20 }

以后再继续完善这个类,现在先去吃饭了。(*^__^*) 嘻嘻……

posted @ 2012-09-14 17:28  lscheng  阅读(1220)  评论(0编辑  收藏  举报