#include<stdlib.h>

#include<cassert>
#include<cstddef>
#include<cstdint>
#include<fstream>
#include<iostream>
#include<memory>
#include<stdexcept>
#include<string>
#include<unordered_map>
#include<vector>

#include<torch/torch.h>
/// Two stacked (3x3 convolution -> batch norm -> ReLU) stages.
/// Padding 1 keeps height/width unchanged; channels go in_ch -> out_ch -> out_ch.
class double_conv:public torch::nn::Module
{
public:
    torch::nn::Conv2d conv1,conv2;
    torch::nn::BatchNorm bn1,bn2;
    int in_ch,out_ch;
public:
    /// @param in_ch  number of channels of the input tensor
    /// @param out_ch number of channels produced by both convolutions
    double_conv(int in_ch,int out_ch)
        // Initializers are listed in declaration order (the order members are
        // actually constructed in). Inside this list the constructor
        // parameters shadow the members of the same name.
        : conv1(torch::nn::Conv2dOptions(in_ch,out_ch,3).padding(1)),
          conv2(torch::nn::Conv2dOptions(out_ch,out_ch,3).padding(1)),
          bn1(out_ch),
          bn2(out_ch),
          in_ch(in_ch),
          out_ch(out_ch)
    {
        // Registration order fixes the order of parameters() entries.
        register_module("conv1",conv1);
        register_module("conv2",conv2);
        register_module("bn1",bn1);
        register_module("bn2",bn2);
    }

    /// Applies conv1 -> bn1 -> relu -> conv2 -> bn2 -> relu.
    torch::Tensor forward(torch::Tensor x)
    {
        x = torch::relu(bn1->forward(conv1->forward(x)));
        return torch::relu(bn2->forward(conv2->forward(x)));
    }
};
class inconv:public torch::nn::Module
{
public:
int in_ch,out_ch;
public:
inconv(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch){}
torch::Tensor forward(torch::Tensor x)
{
double_conv dc(in_ch,out_ch);
x = dc.forward(x);
return x;
}
};
class down:public torch::nn::Module
{
public:
int in_ch,out_ch;
public:
down(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch){}
torch::Tensor forward(torch::Tensor x)
{
x = torch::max_pool2d(x,2);
double_conv dc(in_ch,out_ch);
x = dc.forward(x);
return x;
}
};
class up:public torch::nn::Module
{
public:
int in_ch,out_ch;
torch::nn::Conv2d upconv;
torch::nn::Conv2d conv1,conv2;
torch::nn::BatchNorm bn1,bn2;
torch::Tensor x;
public:
up(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch),upconv(torch::nn::Conv2dOptions(in_ch,out_ch,4).padding(1).stride(2).transposed(new bool(true))),
conv1(torch::nn::Conv2dOptions(out_ch,out_ch,3).padding(1)),bn1(out_ch),conv2(torch::nn::Conv2dOptions(out_ch,out_ch,3).padding(1)),bn2(out_ch)
{
register_module("upconv",upconv);
register_module("conv1",conv2);
register_module("conv2",conv2);
register_module("bn1",bn1);
register_module("bn2",bn2);
}
torch::Tensor forward(torch::Tensor x1,torch::Tensor x2)
{
x = upconv->forward(x1);
x = torch::cat({x,x2},1);
double_conv dc(x.size(1),out_ch);
x = dc.forward(x);
//x = conv1->forward(x);
//x = bn1->forward(x);
//x = torch::relu(x);
//x = conv2->forward(x);
//x = bn2->forward(x);
//x = torch::relu(x);
return x;
}
};
class outconv:public torch::nn::Module
{
public:
int in_ch,out_ch;
torch::nn::Conv2d conv;
public:
outconv(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch),conv(torch::nn::Conv2dOptions(in_ch,out_ch,1).padding(0))
{
register_module("conv",conv);
}
torch::Tensor forward(torch::Tensor x)
{
return conv->forward(x);
}
};
class unet:public torch::nn::Module
{
public:
int n_ch,n_class;
inconv *iconv= new inconv(n_ch,64);
down *down1= new down(64,256);
down *down2= new down(256,512);
down *down3= new down(512,512);
down *down4= new down(512,512);
up *up1= new up(512,256);
up *up2= new up(256,128);
up *up3= new up(128,64);
up *up4= new up(64,64);
outconv *oconv= new outconv(64,n_class);
torch::Tensor x1,x2,x3,x4,x5;
public:
unet(int n_ch,int n_class):n_ch(n_ch),n_class(n_class){}
torch::Tensor forward(torch::Tensor x)
{
x1 = iconv->forward(x);
x2 = down1->forward(x1);
x3 = down2->forward(x2);
x4 = down3->forward(x3);
x5 = down4->forward(x4);
x = up1->forward(x5,x4);
x = up2->forward(x,x3);
x = up3->forward(x,x2);
x = up4->forward(x,x1);
x = oconv->forward(x);
return x;
}
};
/// Splits `str` on any character contained in `delimiters` and converts each
/// non-empty token with std::atof.
/// Runs of delimiters count as one separator. Leading and trailing delimiters
/// are ignored. A token that is not a number converts to 0.
std::vector<float> Tokenize(const std::string& str,const std::string& delimiters)
{
    std::vector<float> values;
    std::string::size_type begin = 0;
    for(;;)
    {
        begin = str.find_first_not_of(delimiters, begin);
        if(begin == std::string::npos)
        {
            break;  // nothing but delimiters (or nothing at all) remains
        }
        const std::string::size_type end = str.find_first_of(delimiters, begin);
        const std::string token =
            str.substr(begin, end == std::string::npos ? std::string::npos : end - begin);
        values.push_back(static_cast<float>(std::atof(token.c_str())));
        if(end == std::string::npos)
        {
            break;  // the token ran to the end of the string
        }
        begin = end;
    }
    return values;
}
/// Reads a text file of space-separated numbers.
///
/// @param file path of the file to read
/// @return one vector of values per line, parsed with Tokenize(line, " ")
///         (a blank line yields an empty vector)
/// @throws std::runtime_error if the file cannot be opened. The previous
///         assert() disappeared in NDEBUG builds, leaving callers with no
///         rows to index.
std::vector<std::vector<float>> readTxt(const std::string& file)
{
    std::ifstream infile(file);
    if(!infile.is_open())
    {
        throw std::runtime_error("readTxt: cannot open " + file);
    }
    std::vector<std::vector<float>> res;
    std::string line;
    while(std::getline(infile,line))
    {
        res.push_back(Tokenize(line, " "));
    }
    std::cout<<"gdood"<<std::endl;
    return res;
}
/// Loads the label raster from a space-separated text file with one raster
/// row per line, into a float tensor of shape {2478,3125}.
///
/// @param path text file to read (defaults to the previously hard-coded path)
/// @return tensor created with tensorFromBlob over a function-local static
///         buffer (wrapped, not copied, per the ATen API). A later call
///         therefore overwrites the data seen through earlier results.
/// @throws std::runtime_error if the file has no lines, does not fit in
///         2478 x 3125, or has lines of differing length
torch::Tensor float2TensorLabel(const std::string& path = "/Users/yanlang/unet/mx-unet/U-Net/LabelData.txt")
{
    constexpr std::size_t kRows = 2478;
    constexpr std::size_t kCols = 3125;
    // About 31 MB: static so it is not placed on the stack. Zeroed once, at
    // program start.
    static float tt[kRows][kCols]={0};
    const std::vector<std::vector<float>> vec = readTxt(path);
    if(vec.empty())
    {
        throw std::runtime_error("float2TensorLabel: no rows in " + path);
    }
    const std::size_t len = vec[0].size();
    if(vec.size() > kRows || len > kCols)
    {
        throw std::runtime_error("float2TensorLabel: " + path + " does not fit in a 2478x3125 buffer");
    }
    for(std::size_t i=0;i<vec.size();i++)
    {
        // Require equal-length rows so the copy never reads past a row's end.
        if(vec[i].size() != len)
        {
            throw std::runtime_error("float2TensorLabel: row " + std::to_string(i+1) + " of " + path +
                                     " has a different length than row 1");
        }
        for(std::size_t j=0;j<len;j++)
        {
            tt[i][j]=vec[i][j];
        }
    }
    return torch::CPU(torch::kFloat).tensorFromBlob(tt,{2478,3125});
}
/// Loads the 7-band input image from a space-separated text file.
/// Each line holds one band's 2478*3125 pixel values; the result is a float
/// tensor of shape {7,2478,3125}.
///
/// @param path text file to read (defaults to the previously hard-coded path)
/// @return tensor created with tensorFromBlob over a function-local static
///         buffer (wrapped, not copied, per the ATen API). A later call
///         therefore overwrites the data seen through earlier results.
/// @throws std::runtime_error if the file has no lines, has more than 7 lines
///         or more than 2478*3125 values per line, or has lines of differing
///         length
torch::Tensor float2TensorData(const std::string& path = "/Users/yanlang/unet/mx-unet/U-Net/ImageData.txt")
{
    constexpr std::size_t kBands = 7;
    constexpr std::size_t kPixels = 2478*3125;
    // About 217 MB: static so it is not placed on the stack. Zeroed once, at
    // program start.
    static float tt[kBands][kPixels] = {0};
    const std::vector<std::vector<float>> vec = readTxt(path);
    if(vec.empty())
    {
        throw std::runtime_error("float2TensorData: no data in " + path);
    }
    const std::size_t len = vec[0].size();
    if(vec.size() > kBands || len > kPixels)
    {
        throw std::runtime_error("float2TensorData: " + path + " does not fit in a 7x(2478*3125) buffer");
    }
    for(std::size_t i=0;i<vec.size();i++)
    {
        // Require equal-length lines so the copy never reads past a line's end.
        if(vec[i].size() != len)
        {
            throw std::runtime_error("float2TensorData: line " + std::to_string(i+1) + " of " + path +
                                     " has a different length than line 1");
        }
        for(std::size_t j=0;j<len;j++)
        {
            tt[i][j]=vec[i][j];
        }
    }
    return torch::CPU(torch::kFloat).tensorFromBlob(tt,{7,2478,3125});
}
// Height and width, in pixels, of the crops cut by RandData / RandMask /
// DataLoader / Get_predData.
int imgH{256};
int imgW{256};
/// Copies the imgH x imgW window whose top-left corner is (hight, width) from
/// the first 7 channels of `data` (indexed data[channel][row][col]).
///
/// @return new tensor of shape {7,imgH,imgW} created by torch::zeros; values
///         are converted to its dtype on copy
torch::Tensor RandData(torch::Tensor data,int hight,int width)
{
    torch::Tensor tmp = torch::zeros({7,imgH,imgW});
    // One strided copy instead of 7*imgH*imgW scalar tensor assignments.
    // narrow() raises if the window does not fit inside `data`.
    tmp.copy_(data.narrow(0,0,7).narrow(1,hight,imgH).narrow(2,width,imgW));
    return tmp;
}
/// Copies the imgH x imgW window whose top-left corner is (hight, width) from
/// the 2-D label map `label` (indexed label[row][col]).
///
/// @return new tensor of shape {imgH,imgW} created by torch::zeros; values
///         are converted to its dtype on copy
torch::Tensor RandMask(torch::Tensor label, int hight,int width)
{
    torch::Tensor tmp = torch::zeros({imgH,imgW});
    // One strided copy instead of imgH*imgW scalar tensor assignments.
    // narrow() raises if the window does not fit inside `label`.
    tmp.copy_(label.narrow(0,hight,imgH).narrow(1,width,imgW));
    return tmp;
}
/// Draws `batch_size` random imgH x imgW crops at matching positions from
/// `data` ({channels,height,width}) and `label` ({height,width}).
///
/// @return {batch data of shape {batch_size,7,imgH,imgW},
///          batch labels of shape {batch_size,imgH,imgW}}
/// @throws std::runtime_error if `data` is smaller than one crop
std::vector<torch::Tensor> DataLoader(torch::Tensor data,torch::Tensor label,int batch_size)
{
    const int64_t imghight = data.size(1);
    const int64_t imgwidth = data.size(2);
    // A crop may start anywhere in [0, size - crop], i.e. (size - crop + 1)
    // positions per axis. The previous modulus (size - crop - 1) excluded the
    // last two and divided by zero or a negative for small images.
    const int64_t rangeH = imghight - imgH + 1;
    const int64_t rangeW = imgwidth - imgW + 1;
    if(rangeH <= 0 || rangeW <= 0)
    {
        throw std::runtime_error("DataLoader: image is smaller than the crop size");
    }
    torch::Tensor resdata = torch::zeros({batch_size,7,imgH,imgW});
    torch::Tensor reslabel = torch::zeros({batch_size,imgH,imgW});
    for(int i=0;i<batch_size;i++)
    {
        const int randhight = static_cast<int>(rand() % rangeH);
        const int randwidth = static_cast<int>(rand() % rangeW);
        resdata[i] = RandData(data,randhight,randwidth);
        reslabel[i] = RandMask(label,randhight,randwidth);
    }
    return {resdata,reslabel};
}
/// Extracts the imgH x imgW prediction window whose top-left corner is
/// (top, left) from the first 7 channels of `data` (indexed
/// data[channel][row][col]) and adds a leading batch dimension.
///
/// @param top  first row of the window (default 500, as before)
/// @param left first column of the window (default 500, as before)
/// @return tensor of shape {1,7,imgH,imgW}
torch::autograd::Variable Get_predData(torch::autograd::Variable data,int top=500,int left=500)
{
    torch::autograd::Variable tmp = torch::zeros({7,imgH,imgW});
    // The window size follows imgH/imgW. The former hard-coded 500..755 range
    // only matched the buffer when imgH == imgW == 256.
    tmp.copy_(data.narrow(0,0,7).narrow(1,top,imgH).narrow(2,left,imgW));
    return torch::unsqueeze(tmp,0);
}
/// Writes every element of the 2-D tensor `data` to tresult.txt in the
/// working directory, one value per line in row-major order.
void write2Txt(torch::autograd::Variable data)
{
    std::ofstream fout("tresult.txt");
    for(int64_t i=0;i<data.size(0);i++)
    {
        for(int64_t j=0;j<data.size(1);j++)
        {
            // Stream the scalar value itself: streaming the 0-dim tensor
            // data[i][j] also writes a tensor type annotation per element.
            // '\n' instead of std::endl avoids a flush per element.
            fout<<data[i][j].toCDouble()<<'\n';
        }
    }
    // fout is flushed and closed by its destructor.
}
void saveModel(std::vector<torch::Tensor> weights,std::vector<std::string> key)
{
std::ofstream fout("unet.txt");
//std::unordered_map<std::string,torch::Tensor> mp;
for(int i=0;i<weights.size();i++)
{
fout<<key[i]<<std::endl;
fout<<weights[i]<<std::endl;
}
fout.close();
}
/// Trains `model`, saves it, and writes a prediction:
/// 1. Runs 20 SGD steps (lr 0.01), each on a batch of two random crops from
///    DataLoader.
/// 2. Writes the parameters to unet.txt via saveModel.
/// 3. Writes the per-pixel argmax over classes for the Get_predData window
///    to tresult.txt via write2Txt.
///
/// @param model network to train (taken by value, as before)
void trainConvNet(unet model)
{
    torch::optim::SGD optimizer(model.parameters(),/*lr=*/0.01);
    std::cout<<"load data ......"<<std::endl;
    torch::autograd::Variable data = torch::autograd::make_variable(float2TensorData());
    torch::autograd::Variable label = torch::autograd::make_variable(float2TensorLabel());
    std::cout<<"done!!"<<std::endl;
    for(int epoch=0;epoch<20;epoch++)
    {
        std::vector<torch::Tensor> vecdata = DataLoader(data,label,2);
        std::cout<<"vecdata after done!!"<<std::endl;
        torch::Tensor train_data = vecdata[0];
        std::cout<<"train_data after done"<<std::endl;
        torch::Tensor train_label = vecdata[1];
        std::cout<<train_label.size(0)<<std::endl;
        std::cout<<"train_label after done"<<std::endl;
        torch::Tensor pred = model.forward(train_data);
        // nll_loss2d expects log-probabilities over the class dimension, but
        // unet::forward ends in a plain 1x1 convolution (raw scores).
        // log_softmax + nll_loss is the standard cross-entropy formulation.
        auto loss = torch::nll_loss2d(torch::log_softmax(pred,1),torch::_cast_Long(train_label));
        std::cout<<"the loss is"<<loss<<std::endl;
        optimizer.zero_grad();
        loss.backward();
        optimizer.step();
    }

    std::vector<torch::Tensor> vecValue;
    std::vector<std::string> vecKey;
    torch::nn::ParameterCursor tt = model.parameters();
    for(auto it=tt.begin();it!=tt.end();++it)
    {
        vecValue.push_back((*it).value);
        vecKey.push_back((*it).key);
    }
    saveModel(vecValue,vecKey);

    torch::autograd::Variable predData = Get_predData(data);
    torch::autograd::Variable fl = model.forward(predData);
    // Remove only the batch dimension. A bare squeeze() would also remove the
    // class dimension when n_class == 1, making argmax(0) run over rows.
    torch::autograd::Variable result = torch::squeeze(fl,0);
    torch::autograd::Variable rt = result.argmax(0);
    std::cout<<rt.size(0)<<std::endl;
    std::cout<<rt.size(1)<<std::endl;
    write2Txt(rt);
}
int main()
{
unet net(7,2);
trainConvNet(net);
return 0;
}