# 基于 Python + MXNet 预训练模型的街景图像语义分割代码

import os
import mxnet as mx
from mxnet import image, gpu
from PIL import Image
import gluoncv
from gluoncv.data.transforms.presets.segmentation import test_transform
from gluoncv.utils.viz import get_color_pallete,plot_image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import pandas as pd
import shutil



# Choose the compute device: use the first GPU when MXNet can see one,
# otherwise fall back to the CPU (previously this line had to be edited by
# hand on machines without a GPU build of MXNet).
ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu(0)


# Load the PSPNet model pre-trained on Cityscapes (pretrained=True fetches the
# weights). 'psp_resnet101_citys' can be replaced by another GluonCV model-zoo
# segmentation model name.
model = gluoncv.model_zoo.get_model('psp_resnet101_citys', ctx=ctx, pretrained=True)


# ---------------------------------------------------------------------------
# Segment every street-view image of every sampling point.
#
# Input : <SRC_ROOT>/<n>/*.png (images) and <SRC_ROOT>/<n>/*.txt (coordinate
#         file), for folder ids n = 1 .. 12853.
# Output: <DST_ROOT>/<n>/ containing a copy of the .txt file, one CSV per image
#         with the fraction of pixels assigned to each class, and one overlay
#         PNG per image.
# Replace both roots with your own paths.
# ---------------------------------------------------------------------------
SRC_ROOT = 'G:/points/from'    # replace with your own path
DST_ROOT = 'G:/points/target'  # replace with your own path

# Class index -> CSV column name. Index i is assumed to be output channel i of
# the 19-class Cityscapes model (same assumption as the 'citys' palette below).
CLASS_NAMES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
               'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
               'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
               'bicycle']
# lng/lat/heading are part of the CSV layout but are not filled in here (they
# stay empty); the coordinate information is kept in the copied .txt file.
CSV_COLUMNS = ['id', 'lng', 'lat', 'heading'] + CLASS_NAMES

for n in range(1, 12854):
    print('The No.{} coordinate is handling'.format(n))
    general_path = '{}/{}'.format(SRC_ROOT, n)

    # os.walk() does not raise for a missing directory (it silently yields
    # nothing), so the existence check has to be explicit.
    if not os.path.isdir(general_path):
        print("No.{} coordinate is missing".format(n))
        continue

    # Only the top level of the folder is used: every path below is built
    # from general_path, so files in sub-folders could not be opened anyway.
    files = sorted(os.listdir(general_path))
    img_names = [f for f in files if f.endswith(".png")]
    lonlat_names = [g for g in files if g.endswith(".txt")]
    if not img_names and not lonlat_names:
        continue

    # Bind save_path for every folder (not only when a .txt exists) so results
    # are never written into the previous point's folder; exist_ok avoids a
    # crash on folders with several .txt files or when the script is re-run.
    save_path = '{}/{}'.format(DST_ROOT, n)
    os.makedirs(save_path, exist_ok=True)

    # Keep the coordinate file(s) next to the segmentation results.
    for lonlat in lonlat_names:
        shutil.copy(os.path.join(general_path, lonlat), save_path)

    for img_file in img_names:
        # Use the whole file stem as the image id. Taking only the first
        # character ("90.png" -> "9") pointed at the wrong file and made the
        # output names of different images collide.
        img_num = os.path.splitext(img_file)[0]
        img_path = os.path.join(general_path, img_file)

        # Segment the image; argmax over axis 1 picks the best class per pixel.
        img = image.imread(img_path)
        img = test_transform(img, ctx=ctx)
        output = model.predict(img)
        predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()

        # Fraction of all pixels assigned to each class.
        n_pixels = predict.size
        row = {'id': img_num}
        for i, name in enumerate(CLASS_NAMES):
            row[name] = np.count_nonzero(predict == i) / n_pixels

        # One-row table per image (Series.append no longer exists in
        # pandas >= 2.0, so the row is built directly).
        df = pd.DataFrame([row], columns=CSV_COLUMNS)
        print('---------Segmentation Is Ok--------')
        df.to_csv(save_path + "/img_seg_csv{}.csv".format(img_num))

        # Optional: save a colour overlay of the segmentation on the photo.
        mask = get_color_pallete(predict, 'citys')
        fig = plt.figure(figsize=(10, 5))
        with Image.open(img_path) as base:
            plt.imshow(base)
        plt.imshow(mask, alpha=0.65)
        plt.axis('off')
        plt.savefig(save_path + "/img_seg_jpg{}.png".format(img_num),
                    dpi=200, bbox_inches='tight')
        plt.close(fig)

  

# posted @ 2023-11-07 21:00  Victooor_swd  阅读(219)  评论(0)    收藏  举报