import os
import mxnet as mx
from mxnet import image, gpu
from PIL import Image
import gluoncv
from gluoncv.data.transforms.presets.segmentation import test_transform
from gluoncv.utils.viz import get_color_pallete,plot_image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import pandas as pd
import shutil
# Choose the compute device: the first GPU when MXNet reports one,
# otherwise fall back to the CPU so CPU-only installs work without edits.
ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu(0)
# Load the PSPNet model pretrained on Cityscapes (downloaded on first use).
# 'psp_resnet101_citys' can be replaced by another gluoncv model-zoo name.
model = gluoncv.model_zoo.get_model('psp_resnet101_citys', ctx=ctx, pretrained=True)
# Root folders for the input images / coordinate files and for the results.
# Replace these with your own paths.
SOURCE_ROOT = 'G:/points/from'
TARGET_ROOT = 'G:/points/target'
# Coordinate folders are numbered FIRST_POINT_ID .. LAST_POINT_ID - 1.
FIRST_POINT_ID = 1
LAST_POINT_ID = 12854

# Class index -> class name for the 19 classes in the model's argmax output.
col_map = {0: 'road', 1: 'sidewalk', 2: 'building', 3: 'wall', 4: 'fence', 5: 'pole',
           6: 'traffic light', 7: 'traffic sign', 8: 'vegetation', 9: 'terrain', 10: 'sky',
           11: 'person', 12: 'rider', 13: 'car', 14: 'truck', 15: 'bus', 16: 'train',
           17: 'motorcycle', 18: 'bicycle'}
# Column layout of every per-image CSV. lng/lat/heading are not filled by this
# script and are written as empty cells (kept so the output schema is unchanged).
CSV_COLUMNS = ['id', 'lng', 'lat', 'heading'] + [col_map[i] for i in range(len(col_map))]


def _segment(img_path):
    """Run the segmentation model on one image file.

    Returns the per-pixel argmax class-index map as a numpy array.
    """
    img = image.imread(img_path)
    img = test_transform(img, ctx=ctx)
    output = model.predict(img)
    return mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()


def _class_ratios(predict):
    """Return {class name: fraction of pixels in ``predict`` labelled with that class index}."""
    total = predict.size
    return {name: np.count_nonzero(predict == idx) / total for idx, name in col_map.items()}


def _save_overlay(predict, img_path, out_path):
    """Save the source image with the colour-coded segmentation mask blended on top."""
    mask = get_color_pallete(predict, 'citys')
    # Context manager closes the PIL file handle once the figure is saved.
    with Image.open(img_path) as base:
        plt.figure(figsize=(10, 5))
        try:
            plt.imshow(base)
            plt.imshow(mask, alpha=0.65)
            plt.axis('off')
            plt.savefig(out_path, dpi=200, bbox_inches='tight')
        finally:
            # Always release the figure so memory does not grow over thousands of images.
            plt.close()


# Walk every coordinate folder and process the images it contains.
for n in range(FIRST_POINT_ID, LAST_POINT_ID):
    print('The No.{} coordinate is handling'.format(n))
    general_path = '{}/{}'.format(SOURCE_ROOT, n)
    # os.walk() does not raise for a missing folder (errors are ignored unless
    # onerror is given), so check for the folder explicitly and skip it.
    if not os.path.isdir(general_path):
        print("No.{} coordinate is missing".format(n))
        continue

    # Only the top level of the folder is used: the file paths below are built
    # directly under general_path, so files in sub-folders could not be opened.
    files = sorted(os.listdir(general_path))
    img_names = [f for f in files if f.endswith(".png")]
    lonlat_names = [f for f in files if f.endswith(".txt")]

    save_path = '{}/{}'.format(TARGET_ROOT, n)
    # Create the output folder once per coordinate; exist_ok keeps re-runs and
    # folders with several .txt files from failing with FileExistsError.
    os.makedirs(save_path, exist_ok=True)
    # Keep a copy of each coordinate file alongside the results.
    for lonlat in lonlat_names:
        shutil.copy(general_path + '/' + lonlat, save_path)

    for img_name in img_names:
        # Use the full file stem (e.g. '12' for '12.png'), not only its first character.
        img_num = os.path.splitext(img_name)[0]
        img_path = general_path + '/' + img_name
        predict = _segment(img_path)
        # One-row table; columns missing from the dict (lng/lat/heading) become NaN.
        # (Series.append, used previously, no longer exists in pandas >= 2.0.)
        df = pd.DataFrame([{'id': img_num, **_class_ratios(predict)}], columns=CSV_COLUMNS)
        print('---------Segmentation Is Ok--------')
        df.to_csv(save_path + "/img_seg_csv{}.csv".format(img_num))
        # Save a visualisation of the segmentation result (optional).
        _save_overlay(predict, img_path, save_path + "/img_seg_jpg{}.png".format(img_num))