由一维数组表示的N维数组实现(C++)

工作中经常需要表示多维数组(如二维矩阵),常见的做法是使用指向指针的指针 T **pArr 逐层分配内存,例如创建 M×N 的二维数组:

T **pArr = new T*[M];// create a 2-D array [M][N]: first an array of M row pointers
for (int i=0;i<M;i++)
{
    pArr[i] = new T[N]; // then one separate allocation of N elements per row
}

销毁内存:

// free every row's allocation first, then the array of row pointers itself
for (int i=0;i<M;i++)
{
    delete[] pArr[i];
}
delete[] pArr;

若是三维数组,则需要使用 T*** pArr 并逐层分配三次,以此类推;维数越高,创建和销毁的代码就越繁琐,也越容易出错(例如内存泄漏)。

 

为方便动态生成多维数组,本文使用一个一维数组存储多维数组的全部元素,并基于 C++ 模板和运算符重载实现多维下标访问与数据切片,从而简化多维数组的处理和使用。

#ifndef MYNDARRAY_H_20170226
#define MYNDARRAY_H_20170226

#include <assert.h>
#include <fstream>
#include <iostream>
using namespace std;

// One-dimensional storage used to represent an N-dimensional array.
typedef unsigned int iNum; // unsigned int is the integer type used for all sizes and indices below

// Closed interval [Start, End] of 0-based indices, used to request a slice.
struct stMyRange {
    iNum Start; // first index (0-based, inclusive)
    iNum End;   // last index (0-based, inclusive)

    // Range [0, 0].
    stMyRange() : Start(0), End(0) {}

    // Range [start, end]; requires start <= end.
    // (No "start >= 0" check: iNum is unsigned, so that always holds.)
    stMyRange(iNum start, iNum end) : Start(start), End(end) {
        assert(end >= start);
    }
};

// One dimension ("rank") of an array: its full length plus the window
// [Start, End] of that dimension that is visible (all of it unless sliced).
struct stMyRank {
    iNum Rank;  // full length of the dimension; 0 means the dimension does not exist
    iNum Start; // first visible index (0-based); Start == 0 && End == Rank-1 means the whole dimension
    iNum End;   // last visible index (0-based, inclusive)

    stMyRank() : Rank(0), Start(0), End(0) {}

    // Makes this a dimension of length `rank` (must be > 0), fully visible.
    void SetRank(iNum rank) {
        assert(rank > 0);
        Rank = rank;
        Start = 0;
        End = Rank - 1;
    }

    // Restricts the visible window to `rang`, given as absolute indices of the
    // full dimension; requires rang.Start <= rang.End < Rank.
    void SetRange(stMyRange rang) {
        assert(rang.End >= rang.Start && rang.End < Rank);
        Start = rang.Start;
        End = rang.End;
    }
};

//目前设置最大数据维度为4,后续根据需要可扩展
#define MyMaxRank 4 template<class T> class MyNDArray{//C++模板使用,可表示多种类型的数据 protected: T* m_pData;//row based storage,行优先存储 stMyRank m_Rank[MyMaxRank];//记录数据维度 iNum m_iRankCap;//total rank capacity;数据维度个数 bool m_isSlice;//if slice==true, it will reuse the m_pData of its parent;//是否是数据切片 public: MyNDArray(){ m_pData = NULL; m_iRankCap = 0; m_isSlice = false; } MyNDArray(iNum rank1){ m_Rank[0].SetRank(rank1); m_iRankCap = 1; m_isSlice = false; Init(); } MyNDArray(iNum rank1,iNum rank2){ m_Rank[0].SetRank(rank1); m_Rank[1].SetRank(rank2); m_iRankCap = 2; m_isSlice = false; Init(); } MyNDArray(iNum rank1,iNum rank2,iNum rank3){ m_Rank[0].SetRank(rank1); m_Rank[1].SetRank(rank2); m_Rank[2].SetRank(rank3); m_iRankCap = 3; m_isSlice = false; Init(); } MyNDArray(iNum rank1,iNum rank2,iNum rank3,iNum rank4){ m_Rank[0].SetRank(rank1); m_Rank[1].SetRank(rank2); m_Rank[2].SetRank(rank3); m_Rank[3].SetRank(rank4); m_iRankCap = 4; m_isSlice = false; Init(); } //创建数据切片,新的ndarray共享partent的pData数据域 MyNDArray(MyNDArray& parent,stMyRank rank1){ m_Rank[0]=rank1; m_iRankCap = 1; m_isSlice = true; m_pData = parent.GetData(); } MyNDArray(MyNDArray& parent,stMyRank rank1,stMyRank rank2){ m_Rank[0]=rank1; m_Rank[1]=rank2; m_iRankCap = 2; m_isSlice = true; m_pData = parent.GetData(); } MyNDArray(MyNDArray& parent,stMyRank rank1,stMyRank rank2,stMyRank rank3){ m_Rank[0]=rank1; m_Rank[1]=rank2; m_Rank[2]=rank3; m_iRankCap = 3; m_isSlice = true; m_pData = parent.GetData(); } MyNDArray(MyNDArray& parent,stMyRank rank1,stMyRank rank2,stMyRank rank3,stMyRank rank4){ m_Rank[0]=rank1; m_Rank[1]=rank2; m_Rank[2]=rank3; m_Rank[3]=rank4; m_iRankCap = 4; m_isSlice = true; m_pData = parent.GetData(); } ~MyNDArray(){ Clear(); } void Clear(){ if(!m_isSlice && m_pData != NULL){ delete[] m_pData; m_pData = NULL; } } T* GetData(){ return m_pData; } void SetData(T* pData){ m_pData = pData; } iNum RankCap(){ return m_iRankCap; } iNum RankSize(iNum rank){//start from 1,获取第i维的维度长度 if(rank < 1) return 0; return 
m_Rank[rank-1].End - m_Rank[rank-1].Start + 1; } MyNDArray Slice(stMyRange rang1){//range start from 0,数据切片 stMyRank rank1 = m_Rank[0]; rank1.SetRange(rang1); MyNDArray res(*this,rank1); return res; } MyNDArray Slice(stMyRange rang1,stMyRange rang2){ stMyRank rank1 = m_Rank[0]; rank1.SetRange(rang1); stMyRank rank2 = m_Rank[1]; rank2.SetRange(rang2); MyNDArray res(*this,rank1,rank2); return res; } MyNDArray Slice(stMyRange rang1,stMyRange rang2,stMyRange rang3){ stMyRank rank1 = m_Rank[0]; rank1.SetRange(rang1); stMyRank rank2 = m_Rank[1]; rank2.SetRange(rang2); stMyRank rank3 = m_Rank[2]; rank3.SetRange(rang3); MyNDArray res(*this,rank1,rank2,rank3); return res; } MyNDArray Slice(stMyRange rang1,stMyRange rang2,stMyRange rang3,stMyRange rang4){ stMyRank rank1 = m_Rank[0]; rank1.SetRange(rang1); stMyRank rank2 = m_Rank[1]; rank2.SetRange(rang2); stMyRank rank3 = m_Rank[2]; rank3.SetRange(rang3); stMyRank rank4 = m_Rank[3]; rank4.SetRange(rang4); MyNDArray res(*this,rank1,rank2,rank3,rank4); return res; } //override operator(), 使用MyNDarray arr; arr(1,2,3,4) T& operator()(iNum r1=0,iNum r2=0,iNum r3=0,iNum r4=0){ assert(m_pData != NULL); return m_pData[StoreIndex(r1,r2,r3,r4)]; } const T& operator()(iNum r1=0,iNum r2=0,iNum r3=0,iNum r4=0) const{ assert(m_pData != NULL); return m_pData[StoreIndex(r1,r2,r3,r4)]; } protected: iNum StoreIndex(iNum r1,iNum r2,iNum r3,iNum r4){//数据存储引擎,以行优先存储 assert(m_iRankCap > 0); if(m_isSlice){ r1 += m_Rank[0].Start; r2 += m_Rank[1].Start; r3 += m_Rank[2].Start; r4 += m_Rank[3].Start; }     //判断数据是否在合法的范围内 if(m_iRankCap >= 1) assert(r1 >= m_Rank[0].Start && r1 <= m_Rank[0].End); if(m_iRankCap >= 2) assert(r2 >= m_Rank[1].Start && r2 <= m_Rank[1].End); if(m_iRankCap >= 3) assert(r3 >= m_Rank[2].Start && r3 <= m_Rank[2].End); if(m_iRankCap >= 4) assert(r4 >= m_Rank[3].Start && r4 <= m_Rank[3].End); iNum index = 0; index = r1 + m_Rank[0].Rank * ( r2 + m_Rank[1].Rank * ( r3 + m_Rank[2].Rank * r4 ));//row based storage return index; } void 
Init(){//初始化内存 iNum size = 1; for (int i=0;i<m_iRankCap;i++) { size *= m_Rank[i].Rank; } if(m_iRankCap > 0 && size > 0) { m_pData = new T[size]; } else m_pData = NULL; } };
//重载输入>>和输出<<,用于数据读取和写入
//override operator >> template<class T> istream& operator >> (istream& myin,MyNDArray<T>& arr){ iNum r1,r2,r3,r4; r1 = arr.RankSize(1); r2 = arr.RankSize(2); r3 = arr.RankSize(3); r4 = arr.RankSize(4); iNum rankCap = arr.RankCap(); if(rankCap < 2) r2 = 1; if(rankCap < 3) r3 = 1; if(rankCap < 4) r4 = 1; for (int i4=0;i4<r4;i4++) { for (int i3=0;i3<r3;i3++) { for (int i2=0;i2<r2;i2++) { for (int i1=0;i1<r1;i1++) { myin >> arr(i1,i2,i3,i4);//调用重载的operator(),读取数据 } } } } return myin; } //override operator << template<class T> ostream& operator << (ostream& myout,MyNDArray<T>& arr){ iNum r1,r2,r3,r4; r1 = arr.RankSize(1); r2 = arr.RankSize(2); r3 = arr.RankSize(3); r4 = arr.RankSize(4); iNum rankCap = arr.RankCap(); if(rankCap < 2) r2 = 1; if(rankCap < 3) r3 = 1; if(rankCap < 4) r4 = 1; for (int i4=0;i4<r4;i4++) { myout<<"rank4:"<<i4+1<<endl; for (int i3=0;i3<r3;i3++) { myout<<"rank3:"<<i3+1<<endl; for (int i2=0;i2<r2;i2++) { for (int i1=0;i1<r1;i1++) { myout <<arr(i1,i2,i3,i4)<<"\t"; } myout<<"\n"; } } myout<<"\n"; } return myout; } #endif

使用方法:

// Demo: loads a 3x3x2x2 int array from ndarr.txt, prints it, then prints the
// 2x2 view ndarr(1..2, 1..2, 1, 1) obtained with Slice().
int main(int argc, char* argv[])
{
    MyNDArray<int> ndarr(3,3,2,2);
    ifstream fin("ndarr.txt",ios::in);
    if (!fin)
    {
        cerr<<"cannot open ndarr.txt"<<endl;
        return 1;
    }
    // 3*3*2*2 = 36 values are expected; without this check a short or
    // malformed file would leave elements unset, and they would be printed.
    if (!(fin>>ndarr))
    {
        cerr<<"ndarr.txt must contain 36 integers"<<endl;
        return 1;
    }
    cout<<ndarr;
    stMyRange range1(1,2);
    stMyRange range2(1,2);
    stMyRange range3(1,1);
    stMyRange range4(1,1);
    MyNDArray<int> slice = ndarr.Slice(range1,range2,range3,range4);
    cout<<"slice:"<<endl<<slice<<endl;
    // NOTE(review): presumably meant to keep the console window open; it
    // actually tries to read 36 more values from stdin into ndarr.
    cin>>ndarr;
    return 0;
}

ndarr.txt 的内容如下(共 3×3×2×2 = 36 个整数,读入时第一维下标变化最快):

1 2 3
4 5 6
7 8 9
11 12 13
14 15 16
17 18 19
21 22 23
24 25 26
27 28 29
31 32 33
34 35 36
37 38 39

上述测试程序中,切片 slice 的输出如下(此处只列出数据行,省略了 "slice:"、"rank4:1"、"rank3:1" 等标题行;元素之间实际以制表符分隔):

35 36

38 39

即切片取到了 ndarr(1..2, 1..2, 1, 1) 这 4 个元素(35、36、38、39),满足数据切片要求。

 

 参考资料:

C++运算符重载:http://www.cnblogs.com/lfsblack/archive/2012/10/01/2709476.html

NumPy 中 ndarray 的简单使用

《C++ Primer》中的模板章节

posted @ 2017-02-26 14:11  小小鸟的大梦想  阅读(1859)  评论(0)  编辑  收藏  举报