EnsNet: Ensconce Text in the Wild 模型训练

参考网址:

$ https://github.com/HCIILAB/Scene-Text-Removal

环境配置

$ git clone https://github.com/HCIILAB/Scene-Text-Removal
$ https://files.pythonhosted.org/packages/b0/e3/0a7bf93413623ec5a1fa42eb3c89f88731a62155f22ca6b1abc8c67c28d3/mxnet_cu90-1.5.0-py2.py3-none-win_amd64.whl
$ pip install mxnet_cu90-1.5.0-py2.py3-none-win_amd64.whl
$ pip install mxnet_cu90-1.5.0-py2.py3-none-win_amd64.whl
$ 验证是否安装成功
Python
>> import mxnet
import 成功说明没有问题。

下载数据集

$ 目的是为了参照它的数据集整理我们自己的数据集。数据格式最好和它的一样,方便先跑通代码。

模型训练

$ 目的是为了参照它的数据集整理我们自己的数据集。数据格式最好和它的一样,方便先跑通代码。

python train.py --trainset_path=’dataset’ --checkpoint=’save_model’ --gpu=0 --lr=0.0002 --n_epoch=5000

网络调整

训练这个网络存在的问题是:

给出来的数据和给的数据读取方式,不匹配,或者至少我没有理解。

解压以后的数据格式为:

syn_train下面包含img和label两个文件夹。

实际读图像的代码如下所示:

class MyDataSet(Dataset):
def __init__(self, root, split, is_transform=False,is_train=True):
self.root = os.path.join(root, split)
self.is_transform = is_transform
self.img_paths = []
self._img_512 = os.path.join(root, split, 'train_512', '{}.png')
self._mask_512 = os.path.join(root, split, 'mask_512', '{}.png')
self._lbl_512 = os.path.join(root, split, 'train_512', '{}.png')
self._img_256 = os.path.join(root, split, 'train_256', '{}.png')
self._lbl_256 = os.path.join(root, split, 'train_256', '{}.png')
self._img_128 = os.path.join(root, split, 'train_128', '{}.png')
for fn in os.listdir(os.path.join(root, split, 'train_512')):
if len(fn) > 3 and fn[-4:] == '.png':
self.img_paths.append(fn[:-4])

def __len__(self):
return len(self.img_paths)

def __getitem__(self, idx):
img_path_512 = self._img_512.format(self.img_paths[idx])
img_path_256 = self._img_256.format(self.img_paths[idx])
img_path_128 = self._img_128.format(self.img_paths[idx])
lbl_path_256 = self._lbl_256.format(self.img_paths[idx])
mask_path_512 = self._mask_512.format(self.img_paths[idx])
lbl_path_512 = self._lbl_512.format(self.img_paths[idx])
img_arr_256 = mx.image.imread(img_path_256).astype(np.float32)/127.5 - 1
img_arr_512 = mx.image.imread(img_path_512).astype(np.float32)/127.5 - 1
img_arr_128 = mx.image.imread(img_path_128).astype(np.float32)/127.5 - 1
img_arr_512 = mx.image.imresize(img_arr_512, img_wd * 2, img_ht)
img_arr_in_512, img_arr_out_512 = [mx.image.fixed_crop(img_arr_512, 0, 0, img_wd, img_ht),
mx.image.fixed_crop(img_arr_512, img_wd, 0, img_wd, img_ht)]
if os.path.exists(mask_path_512):
mask_512 = mx.image.imread(mask_path_512)
else:
mask_512 = mx.image.imread(mask_path_512.replace(".png",'.jpg',1))
tep_mask_512 = nd.slice_axis(mask_512, axis=2, begin=0, end=1)/255
if self.is_transform:
imgs = [img_arr_out_512, img_arr_in_512, tep_mask_512,img_arr_256,img_arr_128]
imgs = random_horizontal_flip(imgs)
imgs = random_rotate(imgs)
img_arr_out_512,img_arr_in_512,tep_mask_512,img_arr_256,img_arr_128 = imgs[0], imgs[1], imgs[2], imgs[3],imgs[4]
img_arr_in_512, img_arr_out_512 = [nd.transpose(img_arr_in_512, (2,0,1)),
nd.transpose(img_arr_out_512, (2,0,1))]
img_arr_out_256 = nd.transpose(img_arr_256, (2,0,1))
img_arr_out_128 = nd.transpose(img_arr_128, (2,0,1))
tep_mask_512 = tep_mask_512.reshape(tep_mask_512.shape[0],tep_mask_512.shape[1],1)
tep_mask_512 = nd.transpose(tep_mask_512,(2,0,1))
return img_arr_out_512,img_arr_in_512,tep_mask_512,img_arr_out_256,img_arr_out_128
不匹配,实际我们没有这么多文件夹。
排查问题的过程:
posted @ 2019-12-10 18:43  皮卡皮卡妞  阅读(738)  评论(0)    收藏  举报