Joint project:
https://github.com/loki-keroro/SAMbase_segmentation?tab=readme-ov-file


The model generates different masks for different text prompts; edit the category_cfg variable in main.py to customize the prompt text.
- landcover_prompts are the land-cover classification prompts; in panorama scenes they are generally used for categories whose regions are contiguous, or for newly added categories.
- cityobject_prompts are the instance-segmentation prompts; in panorama scenes they are generally used for object categories whose regions are not contiguous within the image.
- landcover_prompts_cn and cityobject_prompts_cn give the Chinese name of each category.

The sketch after the configuration example below illustrates how these prompts map to the label IDs in the output mask.
category_cfg = {
"landcover_prompts": ['building', 'low vegetation', 'tree', 'river', 'shed', 'road', 'lake', 'bare soil'],
"landcover_prompts_cn": ['建筑', '低矮植被', '树木', '河流', '棚屋', '道路', '湖泊', '裸土'],
"cityobject_prompts": ['car', 'truck', 'bus', 'train', 'ship', 'boat'],
"cityobject_prompts_cn": ['轿车', '卡车', '巴士', '列车', '船(舰)', '船(舶)']
}
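
PSAM builds its label list as ["background"] followed by the land-cover prompts and then the city-object prompts, and the integer values in the returned mask follow that order. The snippet below is a small illustrative sketch of that mapping for the configuration above; it simply mirrors what the PSAM constructor does.

# Illustrative sketch: how the prompts above map to label IDs in the output mask
# (mirrors the mapping built in PSAM.__init__; background is always ID 0).
category_names = ["background"] + category_cfg["landcover_prompts"] + category_cfg["cityobject_prompts"]
category_name_to_id = {name: i for i, name in enumerate(category_names)}
print(category_name_to_id)
# -> {'background': 0, 'building': 1, 'low vegetation': 2, ..., 'ship': 13, 'boat': 14}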
from inference import PSAM
# Model configuration file and weight file paths
model_cfg = {
"DINO_WEIGHT_PATH": "weights/GSA_weights/groundingdino_swinb_cogcoor.pth",
"DINO_CFG_PATH": "groundingdino/config/GroundingDINO_SwinB.py",
"SAM_WEIGHT_PATH": "weights/GSA_weights/sam_vit_h_4b8939.pth",
"CLIP_WEIGHT_DIR": "weights/CLIP_weights/"
}
# Prompts: the category lists below can be customized
# The model generates a different mask for each prompt
# category_cfg = {
# "landcover_prompts": ['building', 'low vegetation', 'tree', 'water', 'shed', 'road', 'lake', 'bare soil',],
# "landcover_prompts_cn": ['建筑', '低矮植被', '树木', '水体', '棚屋', '道路', '湖泊', '裸土'],
# "cityobject_prompts": ['car', 'truck', 'bus', 'train', 'ship', 'boat'],
# "cityobject_prompts_cn": ['轿车', '卡车', '巴士', '列车', '船(舰)', '船(舶)']
# }
category_cfg = {
"landcover_prompts": ['building', 'water', 'tree', 'road', 'shed', 'cropland', 'grassland', 'Agricultural Fields', 'bare soil'],
"landcover_prompts_cn": ['建筑', '水体', '树木', '道路', '棚屋', '农田', '草地', '农用地', '裸土'],
"cityobject_prompts": ['car', 'truck', 'train'],
"cityobject_prompts_cn": ['轿车', '货车', '火车']
}
gpus = ["1"]
# Render matplotlib output with Chinese labels
cn_style = False # whether to use Chinese labels
font_style_path = '/usr/share/fonts/wqy-microhei/wqy-microhei.ttc' # path to a Chinese font; run fc-list to see the fonts installed on the system
if __name__ == "__main__":
psam = PSAM(model_cfg, category_cfg, gpus)
# img_path = "/home/piesat/data/无人机全景图/panorama01-04/match_imgs/CD_dataset/01->03/A_B/A/100_right_0_1_hw(2701,672).png"
# img_path = "/home/piesat/media/ljh/pycharm_project__ljh/panorama_sam/photos/c1.png"
file_path = '/home/piesat/data/无人机全景图/panorama01-04/match_imgs/CD_dataset/cwptys_tmp/A'
save_path = '/home/piesat/media/ljh/pycharm_project__ljh/panorama_sam/photos/croplands/'
import os
files = []
for root, dirs, filenames in os.walk(file_path):
for filename in filenames:
in_img_path = os.path.join(root, filename)
out_img_path = os.path.join(save_path, filename)
psam.load_image(in_img_path)
panoptic_inds = psam.generate_panoptic_mask()
psam.plt_draw_image(cn_style, font_style_path, out_img_path)
print(panoptic_inds.shape) # panoptic_inds: single-channel mask image (one category ID per pixel)
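
The mask returned by generate_panoptic_mask can also be written out directly, independent of the matplotlib figure. The lines below are a minimal sketch (the output file names are placeholders) that reuses panoptic_inds and psam from the loop above to save the raw ID mask and a colorized version built from PSAM's internal color map.

# Minimal sketch with placeholder file names: save the single-channel ID mask
# and a colorized version of it using PSAM's internal color map.
import numpy as np
from PIL import Image
mask = np.asarray(panoptic_inds, dtype=np.uint8)  # one category ID per pixel
Image.fromarray(mask).save('mask_ids.png')  # raw IDs as a grayscale image
palette = np.array(psam.color_map, dtype=np.uint8)  # (num_categories, 3) RGB colors
Image.fromarray(palette[mask]).save('mask_color.png')  # colorized mask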
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
import torch
from utils.data_utils import generate_color_list
from utils.load_models import load_clip_model, load_dino_model, load_sam_model
from utils.func_utils import dino_detection, sam_masks_from_dino_boxes, clipseg_segmentation, \
clip_and_shrink_preds, sample_points_based_on_preds, sam_mask_from_points, preds_to_semantic_inds
class PSAM(object):
def __init__(self, model_cfg, category_cfg, gpu_ids):
# Initialize the GroundingDINO, SAM, and CLIPSeg models
self.device = torch.device("cuda:%s" % gpu_ids[0] if torch.cuda.is_available() and len(gpu_ids) > 0 else "cpu")
self.groundingdino_model = load_dino_model(model_cfg["DINO_CFG_PATH"], model_cfg["DINO_WEIGHT_PATH"], self.device)
self.sam_predictor = load_sam_model(model_cfg["SAM_WEIGHT_PATH"], self.device)
self.clipseg_processor, self.clipseg_model = load_clip_model(model_cfg["CLIP_WEIGHT_DIR"], self.device)
self.landcover_categories = category_cfg["landcover_prompts"]
self.cityobject_categories = category_cfg["cityobject_prompts"]
self.category_names = ["background"] + self.landcover_categories + self.cityobject_categories
self.category_name_to_id = {
category_name: i for i, category_name in enumerate(self.category_names)
}
self.category_id_to_name = {
i: category_name for i, category_name in enumerate(self.category_names)
}
self.color_map = generate_color_list(len(self.category_names))
self.landcover_categories_cn = category_cfg["landcover_prompts_cn"]
self.cityobject_categories_cn = category_cfg["cityobject_prompts_cn"]
self.category_names_cn = ["背景"] + self.landcover_categories_cn + self.cityobject_categories_cn
self.category_id_to_name_cn = {
i: category_name for i, category_name in enumerate(self.category_names_cn)
}
def load_image(self, img_path):
# Read the image and compute the SAM image embedding
image = Image.open(img_path)
self.image = image.convert("RGB")
self.image_array = np.asarray(self.image)
self.sam_predictor.set_image(self.image_array)
def generate_panoptic_mask(self, dino_box_threshold=0.2,
dino_text_threshold=0.20,
segmentation_background_threshold=0.1,
shrink_kernel_size=10,
num_samples_factor=300
):
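# Parameter notes (a reader's summary of how these values are used below, not upstream documentation):
# - dino_box_threshold / dino_text_threshold: GroundingDINO box/text confidence thresholds
# - segmentation_background_threshold: score below which a pixel is treated as background
# - shrink_kernel_size: kernel size used when clipping/shrinking the CLIPSeg predictions
# - num_samples_factor: scales how many point prompts are sampled per category for SAM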
# 1. City-object detection with GroundingDINO, segmented with SAM
cityobject_category_ids = []
cityobject_masks = torch.empty(0)
cityobject_boxes = []
if len(self.cityobject_categories) > 0:
cityobject_boxes, cityobject_category_ids, _ = dino_detection(
self.groundingdino_model,
self.image,
self.cityobject_categories,
self.category_name_to_id,
dino_box_threshold,
dino_text_threshold,
self.device,
)
if len(cityobject_boxes) > 0:
cityobject_masks = sam_masks_from_dino_boxes(
self.sam_predictor, self.image_array, cityobject_boxes, self.device
)
# 2. Land-cover classification with CLIPSeg, segmented with SAM
if len(self.landcover_categories) > 0:
clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
self.clipseg_processor,
self.clipseg_model,
self.image,
self.landcover_categories,
segmentation_background_threshold,
self.device,
)
clipseg_semantic_inds_without_cityobject = clipseg_semantic_inds.clone()
if len(cityobject_boxes) > 0:
combined_cityobject_mask = torch.any(cityobject_masks, dim=0)
clipseg_semantic_inds_without_cityobject[combined_cityobject_mask[0]] = 0
clipsed_clipped_preds, relative_sizes = clip_and_shrink_preds(
clipseg_semantic_inds_without_cityobject,
clipseg_preds,
shrink_kernel_size,
len(self.landcover_categories) + 1,
)
sam_preds = torch.zeros_like(clipsed_clipped_preds)
for i in range(clipsed_clipped_preds.shape[0]):
clipseg_pred = clipsed_clipped_preds[i]
num_samples = int(relative_sizes[i] * num_samples_factor)
if num_samples == 0:
continue
points = sample_points_based_on_preds(
clipseg_pred.cpu().numpy(), num_samples
)
if len(points) == 0:
continue
pred = sam_mask_from_points(self.sam_predictor, self.image_array, points)
sam_preds[i] = pred
sam_semantic_inds = preds_to_semantic_inds(
sam_preds, segmentation_background_threshold
)
# 3. Combine the city-object and land-cover mask results
if len(self.landcover_categories) > 0:
# Apply morphological opening and closing to clean up the mask
self.panoptic_inds = sam_semantic_inds.clone().cpu().numpy().astype(np.uint8)
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
self.panoptic_inds = cv2.morphologyEx(self.panoptic_inds, cv2.MORPH_OPEN, kernel)
self.panoptic_inds = cv2.morphologyEx(self.panoptic_inds, cv2.MORPH_CLOSE, kernel)
else:
self.panoptic_inds = np.zeros((self.image_array.shape[0], self.image_array.shape[1]), dtype=np.uint8)
for mask_cid in range(cityobject_masks.shape[0]):
ind = cityobject_category_ids[mask_cid]
mask_bool = cityobject_masks[mask_cid].squeeze(dim=0).cpu().numpy()
self.panoptic_inds[mask_bool] = ind
return self.panoptic_inds
def plt_draw_image(self, cn_style=False, font_style_path=None, save_file_path=None):
# Whether to render labels in Chinese
if cn_style and font_style_path is not None:
font = FontProperties(fname=font_style_path)
id_to_name = self.category_id_to_name_cn
else:
cn_style = False
font = FontProperties()
id_to_name = self.category_id_to_name
# Count the pixels occupied by each category using np.unique with return_counts
unique_values, counts = np.unique(self.panoptic_inds, return_counts=True)
count_map = {}
bar_colors = [] # bar color for each category
for value, count in zip(unique_values, counts):
count_map[id_to_name[value]] = count
r, g, b = self.color_map[value]
r = r / 255
g = g / 255
b = b / 255
bar_colors.append((r, g, b, 1.0))
x = list(count_map.keys())
y = list(count_map.values())
# Create the subplots
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 12))
# Original image
axes[0, 0].imshow(self.image)
# Mask image, colorized by category
cm = [list(t) for t in self.color_map]
cm = np.array(cm).astype('uint8')
label_img = cm[self.panoptic_inds]
axes[0, 1].imshow(Image.fromarray(label_img))
# Blended overlay of image and mask
draw_image = cv2.addWeighted(np.array(self.image), 0.7, label_img, 0.3, 0)
axes[1, 0].imshow(Image.fromarray(draw_image))
# Bar chart of per-category pixel counts
axes[1, 1].bar(range(len(x)), y, label=x, color=bar_colors)
# Add value labels to the bars
for a, b in zip(range(len(x)), y):
axes[1, 1].text(a, b, b, ha='center', va='bottom', fontproperties=font)
# Add the title and axis labels
if cn_style:
axes[1, 1].set_title('统计每个类别占用的像素', fontproperties=font)
axes[1, 1].set_xlabel('类别', fontproperties=font)
axes[1, 1].set_ylabel('像素', fontproperties=font)
else:
axes[1, 1].set_title('Pixel Count', fontproperties=font)
axes[1, 1].set_xlabel('Category', fontproperties=font)
axes[1, 1].set_ylabel('Pixel', fontproperties=font)
axes[1, 1].set_xticklabels([])
# Add the legend
axes[1, 1].legend(prop=font)
# Adjust the spacing between subplots
plt.subplots_adjust(wspace=0.15, hspace=0.2)
# Save the figure
plt.savefig(save_file_path)
# # Show the figure
# plt.show()