// C++ 实现 U-Net

#include<stdlib.h>

#include<algorithm>
#include<cassert>
#include<cstdint>
#include<fstream>
#include<iostream>
#include<memory>
#include<stdexcept>
#include<string>
#include<unordered_map>
#include<vector>

#include<torch/torch.h>
// Two stacked stages of (3x3 conv, padding 1) -> batch norm -> ReLU.
// Channels go in_ch -> out_ch -> out_ch.
class double_conv:public torch::nn::Module
{
    public:
        // Declaration order (and therefore construction order) is significant:
        // it fixes the order in which weights are randomly initialised.
        torch::nn::Conv2d conv1,conv2;
        torch::nn::BatchNorm bn1,bn2;
        int in_ch,out_ch;
    public:
        double_conv(int in_ch,int out_ch)
            :conv1(torch::nn::Conv2dOptions(in_ch,out_ch,3).padding(1)),
             conv2(torch::nn::Conv2dOptions(out_ch,out_ch,3).padding(1)),
             bn1(out_ch),
             bn2(out_ch),
             in_ch(in_ch),
             out_ch(out_ch)
        {
            // Registration order determines parameters() ordering.
            register_module("conv1",conv1);
            register_module("conv2",conv2);
            register_module("bn1",bn1);
            register_module("bn2",bn2);
        }
        torch::Tensor forward(torch::Tensor x)
        {
            torch::Tensor stage1 = torch::relu(bn1->forward(conv1->forward(x)));
            return torch::relu(bn2->forward(conv2->forward(stage1)));
        }
};
class inconv:public torch::nn::Module
{
    public:
        int in_ch,out_ch;
    public:
        inconv(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch){}
        torch::Tensor forward(torch::Tensor x)
        {
             double_conv dc(in_ch,out_ch);
             x = dc.forward(x);
             return x;
        }
};
class down:public torch::nn::Module
{
    public:
        int in_ch,out_ch;
    public:
        down(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch){}
        torch::Tensor forward(torch::Tensor x)
        {
            x = torch::max_pool2d(x,2);
            double_conv dc(in_ch,out_ch);
            x = dc.forward(x);
            return x;
        }
};
class up:public torch::nn::Module
{
    public:
        int in_ch,out_ch;
        torch::nn::Conv2d upconv;
        torch::nn::Conv2d conv1,conv2;
        torch::nn::BatchNorm bn1,bn2;
        torch::Tensor x;
    public:
        up(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch),upconv(torch::nn::Conv2dOptions(in_ch,out_ch,4).padding(1).stride(2).transposed(new bool(true))),
                                 conv1(torch::nn::Conv2dOptions(out_ch,out_ch,3).padding(1)),bn1(out_ch),conv2(torch::nn::Conv2dOptions(out_ch,out_ch,3).padding(1)),bn2(out_ch)
        {
            register_module("upconv",upconv);
            register_module("conv1",conv2);
            register_module("conv2",conv2);
            register_module("bn1",bn1);
            register_module("bn2",bn2);
        }
        torch::Tensor forward(torch::Tensor x1,torch::Tensor x2)
        {
            x = upconv->forward(x1);
            x = torch::cat({x,x2},1);
            double_conv dc(x.size(1),out_ch);
            x = dc.forward(x);
            //x = conv1->forward(x);
            //x = bn1->forward(x);
            //x = torch::relu(x);
            //x = conv2->forward(x);
            //x = bn2->forward(x);
            //x = torch::relu(x);
            return x;
        }
};
class outconv:public torch::nn::Module
{
    public:
        int in_ch,out_ch;
        torch::nn::Conv2d conv;
    public:
        outconv(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch),conv(torch::nn::Conv2dOptions(in_ch,out_ch,1).padding(0))
        {
            register_module("conv",conv);
        }
        torch::Tensor forward(torch::Tensor x)
        {
            return conv->forward(x);
        }
};
class unet:public torch::nn::Module
{
    public:
        int n_ch,n_class;
        inconv *iconv= new inconv(n_ch,64);
        down *down1= new down(64,256);
        down *down2= new down(256,512);
        down *down3= new down(512,512);
        down *down4= new down(512,512);
        up *up1= new up(512,256);
        up *up2= new up(256,128);
        up *up3= new up(128,64);
        up *up4= new up(64,64);
        outconv *oconv= new outconv(64,n_class);
        torch::Tensor x1,x2,x3,x4,x5;
    public:
        unet(int n_ch,int n_class):n_ch(n_ch),n_class(n_class){}
        torch::Tensor forward(torch::Tensor x)
        {
           x1 = iconv->forward(x);
           x2 = down1->forward(x1);
           x3 = down2->forward(x2);
           x4 = down3->forward(x3);
           x5 = down4->forward(x4);
           x = up1->forward(x5,x4);
           x = up2->forward(x,x3);
           x = up3->forward(x,x2);
           x = up4->forward(x,x1);
           x = oconv->forward(x);
           return x;
        }
};
// Splits `str` at any character contained in `delimiters` and converts each
// resulting token with std::atof. Runs of delimiters (and leading/trailing
// delimiters) produce no empty tokens.
std::vector<float> Tokenize(const std::string& str,const std::string& delimiters)
{
    std::vector<float> values;
    std::string::size_type begin = str.find_first_not_of(delimiters);
    while(begin != std::string::npos)
    {
        const std::string::size_type end = str.find_first_of(delimiters, begin);
        const std::string token = (end == std::string::npos)
                                      ? str.substr(begin)
                                      : str.substr(begin, end - begin);
        values.push_back(static_cast<float>(std::atof(token.c_str())));
        if(end == std::string::npos)
        {
            break;
        }
        begin = str.find_first_not_of(delimiters, end);
    }
    return values;
}
// Reads a whitespace-separated text file: one std::vector<float> per line
// (each line tokenised on ' ' by Tokenize).
// Throws std::runtime_error if the file cannot be opened. The previous
// assert() vanished under NDEBUG and callers then indexed an empty result.
std::vector<std::vector<float>> readTxt(std::string file)
{
    std::ifstream infile(file);
    if(!infile.is_open())
    {
        throw std::runtime_error("readTxt: cannot open " + file);
    }
    std::vector<std::vector<float>> res;
    std::string line;
    while(std::getline(infile,line))
    {
        res.push_back(Tokenize(line," "));
    }
    std::cout<<"gdood"<<std::endl;
    return res;
}
// Loads the label grid from LabelData.txt into a 2478x3125 float tensor.
// Cells not covered by the file stay 0.
// Throws std::runtime_error if the file is empty or larger than the buffer.
torch::Tensor float2TensorLabel()
{
    constexpr int64_t kRows = 2478;
    constexpr int64_t kCols = 3125;
    // ~31 MB, so static rather than on the stack. The returned tensor wraps
    // this buffer via tensorFromBlob (presumably without copying), which is
    // why it must outlive the call.
    static float tt[kRows][kCols]={0};
    std::vector<std::vector<float>> vec = readTxt("/Users/yanlang/unet/mx-unet/U-Net/LabelData.txt");
    if(vec.empty())
    {
        throw std::runtime_error("float2TensorLabel: label file is empty");
    }
    if(vec.size()>static_cast<std::size_t>(kRows) || vec[0].size()>static_cast<std::size_t>(kCols))
    {
        throw std::runtime_error("float2TensorLabel: label file exceeds 2478x3125");
    }
    // The row width is taken from the first line. Shorter rows are zero-padded
    // instead of being read past their end.
    const std::size_t len = vec[0].size();
    for(std::size_t i=0;i<vec.size();i++)
    {
        const std::size_t n = std::min(len,vec[i].size());
        std::copy(vec[i].begin(),vec[i].begin()+static_cast<std::ptrdiff_t>(n),tt[i]);
    }
    torch::Tensor tmask = torch::CPU(torch::kFloat).tensorFromBlob(tt,{kRows,kCols});
    return tmask;
}
// Loads ImageData.txt (one flattened 2478*3125 channel per line, up to 7
// lines) into a 7x2478x3125 float tensor. Missing values stay 0.
// Throws std::runtime_error if the file is empty or larger than the buffer.
torch::Tensor float2TensorData()
{
    constexpr int64_t kChannels = 7;
    constexpr int64_t kPixels = 2478*3125;
    // ~217 MB, so static rather than on the stack. The returned tensor wraps
    // this buffer via tensorFromBlob (presumably without copying).
    static float tt[kChannels][kPixels] = {0};
    std::vector<std::vector<float>> vec = readTxt("/Users/yanlang/unet/mx-unet/U-Net/ImageData.txt");
    if(vec.empty())
    {
        throw std::runtime_error("float2TensorData: image file is empty");
    }
    if(vec.size()>static_cast<std::size_t>(kChannels) || vec[0].size()>static_cast<std::size_t>(kPixels))
    {
        throw std::runtime_error("float2TensorData: image file exceeds 7 x (2478*3125) values");
    }
    // Width is taken from the first line. Shorter lines are zero-padded rather than over-read.
    const std::size_t len = vec[0].size();
    for(std::size_t i=0;i<vec.size();i++)
    {
        const std::size_t n = std::min(len,vec[i].size());
        std::copy(vec[i].begin(),vec[i].begin()+static_cast<std::ptrdiff_t>(n),tt[i]);
    }
    torch::Tensor tdata = torch::CPU(torch::kFloat).tensorFromBlob(tt,{kChannels,2478,3125});
    return tdata;
}
// Height and width (in pixels) of every crop fed to the network.
int imgH{256};
int imgW{256};
// Returns a 7 x imgH x imgW copy of `data` covering channels [0,7),
// rows [hight, hight+imgH) and columns [width, width+imgW).
// One narrow()+copy_ replaces 7*imgH*imgW scalar index assignments.
// Copying into a torch::zeros tensor keeps the original result dtype.
torch::Tensor RandData(torch::Tensor data,int hight,int width)
{
    torch::Tensor tmp = torch::zeros({7,imgH,imgW});
    tmp.copy_(data.narrow(0,0,7).narrow(1,hight,imgH).narrow(2,width,imgW));
    return tmp;
}
// Returns an imgH x imgW copy of `label` covering rows [hight, hight+imgH)
// and columns [width, width+imgW), using one slice copy instead of per-element assignment.
torch::Tensor RandMask(torch::Tensor label, int hight,int width)
{
    torch::Tensor tmp = torch::zeros({imgH,imgW});
    tmp.copy_(label.narrow(0,hight,imgH).narrow(1,width,imgW));
    return tmp;
}
// Samples `batch_size` random imgH x imgW crops at the same positions from
// `data` (indexed as channels x H x W) and `label` (H x W).
// Returns {data batch [batch_size,7,imgH,imgW], label batch [batch_size,imgH,imgW]}.
// Throws std::runtime_error if the image is smaller than the crop.
std::vector<torch::Tensor> DataLoader(torch::Tensor data,torch::Tensor label,int batch_size)
{
    const int64_t imghight = data.size(1);
    const int64_t imgwidth = data.size(2);
    // Guards the modulo below against a zero or negative range.
    if(imghight<imgH || imgwidth<imgW)
    {
        throw std::runtime_error("DataLoader: image is smaller than the crop size");
    }
    torch::Tensor resdata = torch::zeros({batch_size,7,imgH,imgW});
    torch::Tensor reslabel = torch::zeros({batch_size,imgH,imgW});
    for(int i=0;i<batch_size;i++)
    {
        // Valid top-left offsets are 0..(size - crop) inclusive: size-crop+1 choices.
        const int randhight = static_cast<int>(rand()%(imghight-imgH+1));
        const int randwidth = static_cast<int>(rand()%(imgwidth-imgW+1));
        resdata[i] = RandData(data,randhight,randwidth);
        reslabel[i] = RandMask(label,randhight,randwidth);
    }
    return {resdata,reslabel};
}
// Builds the fixed prediction input: channels [0,7) of `data` cropped at
// offset (500, 500) to imgH x imgW, with a leading batch axis of size 1.
// The crop size now follows imgH/imgW instead of a hard-coded 256-pixel loop,
// and one slice copy replaces per-element assignment.
torch::autograd::Variable Get_predData(torch::autograd::Variable data)
{
    const int64_t offset = 500;
    torch::autograd::Variable tmp = torch::zeros({7,imgH,imgW});
    tmp.copy_(data.narrow(0,0,7).narrow(1,offset,imgH).narrow(2,offset,imgW));
    return torch::unsqueeze(tmp,0);
}
// Writes every element data[i][j] of a 2-D tensor to "tresult.txt", one per
// line in row-major order. The textual form of each element is whatever the
// tensor's operator<< prints (TODO confirm it is the bare value).
// Uses '\n' rather than std::endl to avoid a flush per element. The stream
// closes when fout leaves scope.
void write2Txt(torch::autograd::Variable data)
{
    std::ofstream fout("tresult.txt");
    if(!fout)
    {
        std::cerr<<"write2Txt: cannot open tresult.txt"<<std::endl;
        return;
    }
    for(int64_t i=0;i<data.size(0);i++)
    {
        for(int64_t j=0;j<data.size(1);j++)
        {
            fout<<data[i][j]<<'\n';
        }
    }
}
void saveModel(std::vector<torch::Tensor> weights,std::vector<std::string> key)
{
    std::ofstream fout("unet.txt");
    //std::unordered_map<std::string,torch::Tensor> mp;
    for(int i=0;i<weights.size();i++)
    {
        fout<<key[i]<<std::endl;
        fout<<weights[i]<<std::endl;
    }
    fout.close();
}
// Trains `model` with SGD (lr 0.01) for 20 iterations, each on a batch of
// 2 random crops. It then saves the parameters to unet.txt, predicts the
// fixed crop from Get_predData, and writes the per-pixel argmax class to tresult.txt.
void trainConvNet(unet model)
{
    torch::optim::SGD optimizer(model.parameters(),/*lr=*/0.01);
    torch::Tensor pred;
    std::cout<<"load data ......"<<std::endl;
    torch::autograd::Variable data = torch::autograd::make_variable(float2TensorData());
    torch::autograd::Variable label = torch::autograd::make_variable(float2TensorLabel());
    std::cout<<"done!!"<<std::endl;
    torch::Tensor train_data,train_label;
    std::vector<torch::Tensor> vecdata;
    for(int epoch=0;epoch<20;epoch++)
    {
        vecdata = DataLoader(data,label,2);
        std::cout<<"vecdata after done!!"<<std::endl;
        train_data = vecdata[0];
        std::cout<<"train_data after done"<<std::endl;
        train_label = vecdata[1];
        std::cout<<train_label.size(0)<<std::endl;
        std::cout<<"train_label after done"<<std::endl;
        pred = model.forward(train_data);
        // nll_loss2d expects log-probabilities, but the network ends in a plain
        // 1x1 conv (raw scores), so normalise over the class dimension first.
        auto loss = torch::nll_loss2d(torch::log_softmax(pred,1),torch::_cast_Long(train_label));
        std::cout<<"the loss is"<<loss<<std::endl;
        optimizer.zero_grad();
        loss.backward();
        optimizer.step();
    }
    std::vector<torch::Tensor> vecValue;
    std::vector<std::string> vecKey;
    torch::nn::ParameterCursor tt = model.parameters();
    for(auto it=tt.begin();it!=tt.end();it++)
    {
        vecValue.push_back((*it).value);
        vecKey.push_back((*it).key);
    }
    saveModel(vecValue,vecKey);
    // Switch BatchNorm layers to inference behaviour before predicting.
    model.eval();
    torch::autograd::Variable predData = Get_predData(data);
    torch::autograd::Variable fl = model.forward(predData);
    torch::autograd::Variable result = torch::squeeze(fl);
    // argmax over the class axis gives the predicted label per pixel.
    torch::autograd::Variable rt = result.argmax(0);
    std::cout<<rt.size(0)<<std::endl;
    std::cout<<rt.size(1)<<std::endl;
    write2Txt(rt);
}
// Entry point: builds a U-Net for 7 input channels and 2 output classes and trains it.
int main()
{
    const int kInputChannels = 7;
    const int kNumClasses = 2;
    unet net(kInputChannels,kNumClasses);
    trainConvNet(net);
    return 0;
}

 

// posted @ 2018-10-12 15:19  semen  阅读(1545)  评论(1)  编辑  收藏  举报