• 博客园logo
  • 会员
  • 众包
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • HarmonyOS
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录
MKT-porter
博客园    首页    新随笔    联系   管理    订阅  订阅
语言 + 目标检测 + 粗略分割模型联合 + SAM 精细化分割 = 快速、精细、带类别的分割

 

 

联合项目

 

https://github.com/loki-keroro/SAMbase_segmentation?tab=readme-ov-file

image

 

image

 

自定义提示

模型会根据不同的提示文本,生成不同的掩码,可修改main.py中的category_cfg变量,自定义提示文本。

  • landcover_prompts 为地物分类的提示,在全景图场景下一般用于分割区域连续或新增的类别
  • cityobject_prompts 为实例分割的提示,在全景图场景下一般用于分割图像内区域不连续的对象类别
  • landcover_prompts_cn和cityobject_prompts_cn为每个类别的中文含义
# Example prompt configuration: each "*_cn" list gives the Chinese display
# name for the prompt at the same position in its English counterpart.
category_cfg = dict(
    # Land-cover prompts.
    landcover_prompts=[
        'building',
        'low vegetation',
        'tree',
        'river',
        'shed',
        'road',
        'lake',
        'bare soil',
    ],
    landcover_prompts_cn=[
        '建筑',
        '低矮植被',
        '树木',
        '河流',
        '棚屋',
        '道路',
        '湖泊',
        '裸土',
    ],
    # City-object prompts.
    cityobject_prompts=['car', 'truck', 'bus', 'train', 'ship', 'boat'],
    cityobject_prompts_cn=['轿车', '卡车', '巴士', '列车', '船(舰)', '船(舶)'],
)

  

 

from inference import PSAM

import os

# Model configuration files and weight files.
model_cfg = {
    "DINO_WEIGHT_PATH": "weights/GSA_weights/groundingdino_swinb_cogcoor.pth",
    "DINO_CFG_PATH": "groundingdino/config/GroundingDINO_SwinB.py",
    "SAM_WEIGHT_PATH": "weights/GSA_weights/sam_vit_h_4b8939.pth",
    "CLIP_WEIGHT_DIR": "weights/CLIP_weights/"
}

# Prompt lists; edit them to define custom categories. The model produces
# different masks for different prompts. Each "*_cn" list must have the same
# length and order as its English counterpart, because both are indexed by the
# same category id.
category_cfg = {
    "landcover_prompts": [
        'building', 'water', 'tree', 'road', 'shed', 'cropland', 'grassland',
        'Agricultural Fields', 'bare soil',
    ],
    "landcover_prompts_cn": ['建筑', '水体', '树木', '道路', '棚屋', '农田', '草地', '农用地', '裸土'],
    "cityobject_prompts": ['car', 'truck', 'train'],
    "cityobject_prompts_cn": ['轿车', '货车', '火车']
}

gpus = ["1"]

# Chinese text rendering in matplotlib.
cn_style = False  # whether to label the plots in Chinese
# Path to a font with Chinese glyphs; `fc-list` shows the installed fonts.
font_style_path = '/usr/share/fonts/wqy-microhei/wqy-microhei.ttc'

# Only files with these extensions are processed. The output is written with
# the same file name, so the extension must also be one the plot can be saved as.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".tif", ".tiff")


def collect_image_jobs(input_dir, output_dir, extensions=IMAGE_EXTENSIONS):
    """Return (input_path, output_path) pairs for every image under input_dir.

    The directory tree is walked recursively. Output paths are flattened into
    ``output_dir`` using the bare file name, so files with the same name in
    different sub-directories map to the same output path.

    Args:
        input_dir: Root directory to search.
        output_dir: Directory the output file names are joined onto.
        extensions: Tuple of lower-case suffixes; matching is case-insensitive.

    Returns:
        A list of ``(input_path, output_path)`` tuples in ``os.walk`` order.
    """
    jobs = []
    for root, _dirs, filenames in os.walk(input_dir):
        for filename in filenames:
            # Skip stray non-image files that would abort the whole batch.
            if not filename.lower().endswith(tuple(extensions)):
                continue
            jobs.append((os.path.join(root, filename), os.path.join(output_dir, filename)))
    return jobs


if __name__ == "__main__":
    psam = PSAM(model_cfg, category_cfg, gpus)

    # Directory of input images and directory receiving the rendered figures.
    file_path = '/home/piesat/data/无人机全景图/panorama01-04/match_imgs/CD_dataset/cwptys_tmp/A'
    save_path = '/home/piesat/media/ljh/pycharm_project__ljh/panorama_sam/photos/croplands/'

    # Saving fails if the target directory is missing, so create it up front.
    os.makedirs(save_path, exist_ok=True)

    for in_img_path, out_img_path in collect_image_jobs(file_path, save_path):
        psam.load_image(in_img_path)
        panoptic_inds = psam.generate_panoptic_mask()
        psam.plt_draw_image(cn_style, font_style_path, out_img_path)
        print(panoptic_inds.shape)  # panoptic_inds: single-channel mask image

  

 

import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

import torch

from utils.data_utils import generate_color_list
from utils.load_models import load_clip_model, load_dino_model, load_sam_model
from utils.func_utils import dino_detection, sam_masks_from_dino_boxes, clipseg_segmentation, \
    clip_and_shrink_preds, sample_points_based_on_preds, sam_mask_from_points, preds_to_semantic_inds


class PSAM(object):
    """Prompt-driven panoptic segmentation built on GroundingDINO, CLIPSeg and SAM.

    Typical use: construct once, then for each image call ``load_image``,
    ``generate_panoptic_mask`` and optionally ``plt_draw_image``.

    Category ids are positions in ``["background"] + landcover prompts +
    city-object prompts``, so id 0 is always the background.
    """

    def __init__(self, model_cfg, category_cfg, gpu_ids):
        """Load the GroundingDINO, SAM and CLIPSeg models and build category tables.

        Args:
            model_cfg: Mapping with the keys ``"DINO_CFG_PATH"``,
                ``"DINO_WEIGHT_PATH"``, ``"SAM_WEIGHT_PATH"`` and
                ``"CLIP_WEIGHT_DIR"``.
            category_cfg: Mapping with the keys ``"landcover_prompts"``,
                ``"cityobject_prompts"`` and their ``"*_cn"`` counterparts
                (Chinese display names, in the same order).
            gpu_ids: Sequence of GPU indices; only the first one is used. The
                CPU is used when the sequence is empty or CUDA is unavailable.
        """
        self.device = torch.device("cuda:%s" % gpu_ids[0] if torch.cuda.is_available() and len(gpu_ids) > 0 else "cpu")
        self.groundingdino_model = load_dino_model(model_cfg["DINO_CFG_PATH"], model_cfg["DINO_WEIGHT_PATH"], self.device)
        self.sam_predictor = load_sam_model(model_cfg["SAM_WEIGHT_PATH"], self.device)
        self.clipseg_processor, self.clipseg_model = load_clip_model(model_cfg["CLIP_WEIGHT_DIR"], self.device)

        # English category tables.
        self.landcover_categories = category_cfg["landcover_prompts"]
        self.cityobject_categories = category_cfg["cityobject_prompts"]
        self.category_names = ["background"] + self.landcover_categories + self.cityobject_categories
        self.category_name_to_id = {
            category_name: i for i, category_name in enumerate(self.category_names)
        }
        self.category_id_to_name = dict(enumerate(self.category_names))
        self.color_map = generate_color_list(len(self.category_names))

        # Chinese display names, indexed by the same category ids.
        self.landcover_categories_cn = category_cfg["landcover_prompts_cn"]
        self.cityobject_categories_cn = category_cfg["cityobject_prompts_cn"]
        self.category_names_cn = ["背景"] + self.landcover_categories_cn + self.cityobject_categories_cn
        self.category_id_to_name_cn = dict(enumerate(self.category_names_cn))

    def load_image(self, img_path):
        """Read an image, convert it to RGB and hand it to the SAM predictor.

        Sets ``self.image`` (PIL image) and ``self.image_array`` (NumPy view of it).

        Args:
            img_path: Path of the image file to open.
        """
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released right away instead of lingering until GC.
        with Image.open(img_path) as image:
            self.image = image.convert("RGB")
        self.image_array = np.asarray(self.image)
        self.sam_predictor.set_image(self.image_array)

    def generate_panoptic_mask(self, dino_box_threshold=0.2,
                               dino_text_threshold=0.20,
                               segmentation_background_threshold=0.1,
                               shrink_kernel_size=10,
                               num_samples_factor=300
    ):
        """Compute the category-id mask for the image set by ``load_image``.

        Args:
            dino_box_threshold: Box threshold forwarded to ``dino_detection``.
            dino_text_threshold: Text threshold forwarded to ``dino_detection``.
            segmentation_background_threshold: Threshold forwarded to
                ``clipseg_segmentation`` and ``preds_to_semantic_inds``.
            shrink_kernel_size: Kernel size forwarded to ``clip_and_shrink_preds``.
            num_samples_factor: Multiplied by each class's relative size to get
                the number of prompt points sampled for SAM.

        Returns:
            ``self.panoptic_inds``: a uint8 NumPy array of category ids
            (0 = background). With no land-cover prompts it has the image's
            height and width; otherwise its shape follows the output of
            ``preds_to_semantic_inds`` (presumably the same — confirm).
        """
        # 1. Detect city objects with DINO and segment each box with SAM.
        cityobject_category_ids = []
        cityobject_masks = torch.empty(0)
        cityobject_boxes = []
        if len(self.cityobject_categories) > 0:
            cityobject_boxes, cityobject_category_ids, _ = dino_detection(
                self.groundingdino_model,
                self.image,
                self.cityobject_categories,
                self.category_name_to_id,
                dino_box_threshold,
                dino_text_threshold,
                self.device,
            )
            if len(cityobject_boxes) > 0:
                cityobject_masks = sam_masks_from_dino_boxes(
                    self.sam_predictor, self.image_array, cityobject_boxes, self.device
                )

        # 2. Classify land cover with CLIPSeg and refine each class with SAM.
        if len(self.landcover_categories) > 0:
            clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
                self.clipseg_processor,
                self.clipseg_model,
                self.image,
                self.landcover_categories,
                segmentation_background_threshold,
                self.device,
            )
            # Blank out pixels already claimed by city objects so that no
            # land-cover prompt points are sampled on them.
            clipseg_semantic_inds_without_cityobject = clipseg_semantic_inds.clone()
            if len(cityobject_boxes) > 0:
                combined_cityobject_mask = torch.any(cityobject_masks, dim=0)
                # Index [0] drops a leading singleton dimension
                # (assumes masks are (N, 1, H, W) — TODO confirm).
                clipseg_semantic_inds_without_cityobject[combined_cityobject_mask[0]] = 0
            clipsed_clipped_preds, relative_sizes = clip_and_shrink_preds(
                clipseg_semantic_inds_without_cityobject,
                clipseg_preds,
                shrink_kernel_size,
                len(self.landcover_categories) + 1,  # +1 for the background class
            )
            sam_preds = torch.zeros_like(clipsed_clipped_preds)
            for i in range(clipsed_clipped_preds.shape[0]):
                clipseg_pred = clipsed_clipped_preds[i]
                num_samples = int(relative_sizes[i] * num_samples_factor)
                if num_samples == 0:
                    continue
                points = sample_points_based_on_preds(
                    clipseg_pred.cpu().numpy(), num_samples
                )
                if len(points) == 0:
                    continue
                pred = sam_mask_from_points(self.sam_predictor, self.image_array, points)

                sam_preds[i] = pred

            sam_semantic_inds = preds_to_semantic_inds(
                sam_preds, segmentation_background_threshold
            )

        # 3. Merge the land-cover result with the city-object masks.
        if len(self.landcover_categories) > 0:
            # Morphological opening followed by closing on the label map.
            self.panoptic_inds = sam_semantic_inds.clone().cpu().numpy().astype(np.uint8)
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
            self.panoptic_inds = cv2.morphologyEx(self.panoptic_inds, cv2.MORPH_OPEN, kernel)
            self.panoptic_inds = cv2.morphologyEx(self.panoptic_inds, cv2.MORPH_CLOSE, kernel)
        else:
            self.panoptic_inds = np.zeros((self.image_array.shape[0], self.image_array.shape[1]), dtype=np.uint8)

        # City objects are painted last, so they override land cover.
        for mask_cid in range(cityobject_masks.shape[0]):
            ind = cityobject_category_ids[mask_cid]
            mask_bool = cityobject_masks[mask_cid].squeeze(dim=0).cpu().numpy()
            self.panoptic_inds[mask_bool] = ind

        return self.panoptic_inds

    def plt_draw_image(self, cn_style=False, font_style_path=None, save_file_path=None):
        """Render a 2x2 summary figure of the latest mask.

        The panels are the original image, the colour-coded mask, the mask
        blended over the image, and a bar chart of pixel counts per category.
        Requires ``load_image`` and ``generate_panoptic_mask`` to have run.

        Args:
            cn_style: Use Chinese labels; only honoured when
                ``font_style_path`` is also given.
            font_style_path: Path to a font file used for the Chinese labels.
            save_file_path: Where to save the figure. When given, the figure is
                saved and then closed. When ``None``, nothing is saved and the
                figure stays open so the caller can show it with ``plt.show()``.
        """
        # Chinese labels need a font that has the glyphs; otherwise use English.
        use_cn = bool(cn_style) and font_style_path is not None
        if use_cn:
            font = FontProperties(fname=font_style_path)
            id_to_name = self.category_id_to_name_cn
        else:
            font = FontProperties()
            id_to_name = self.category_id_to_name

        # Pixel count per category present in the mask. Parallel lists keep
        # labels, counts and colours aligned even if two ids share a name.
        unique_values, counts = np.unique(self.panoptic_inds, return_counts=True)
        labels = [id_to_name[value] for value in unique_values]
        pixel_counts = list(counts)
        bar_colors = []
        for value in unique_values:
            r, g, b = self.color_map[value]
            bar_colors.append((r / 255, g / 255, b / 255, 1.0))
        positions = range(len(labels))

        fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 12))

        # Original image.
        axes[0, 0].imshow(self.image)

        # Colour-coded mask: look up each category id in the palette.
        palette = np.array([list(color) for color in self.color_map]).astype('uint8')
        label_img = palette[self.panoptic_inds]
        axes[0, 1].imshow(Image.fromarray(label_img))

        # Mask blended over the original image.
        draw_image = cv2.addWeighted(np.array(self.image), 0.7, label_img, 0.3, 0)
        axes[1, 0].imshow(Image.fromarray(draw_image))

        # Bar chart with a value label above each bar.
        chart = axes[1, 1]
        chart.bar(positions, pixel_counts, label=labels, color=bar_colors)
        for position, pixel_count in zip(positions, pixel_counts):
            chart.text(position, pixel_count, pixel_count, ha='center', va='bottom', fontproperties=font)
        if use_cn:
            chart.set_title('统计每个类别占用的像素', fontproperties=font)
            chart.set_xlabel('类别', fontproperties=font)
            chart.set_ylabel('像素', fontproperties=font)
        else:
            chart.set_title('Pixel Count', fontproperties=font)
            chart.set_xlabel('Category', fontproperties=font)
            chart.set_ylabel('Pixel', fontproperties=font)
        # Categories are identified by the legend, not by tick labels.
        chart.set_xticklabels([])
        chart.legend(prop=font)

        fig.subplots_adjust(wspace=0.15, hspace=0.2)

        if save_file_path is not None:
            try:
                fig.savefig(save_file_path)
            finally:
                # pyplot keeps every figure alive until it is closed; without
                # this a batch run accumulates one open figure per image.
                plt.close(fig)

  

posted on 2025-10-22 06:24  MKT-porter  阅读(3)  评论(0)    收藏  举报
刷新页面返回顶部
博客园  ©  2004-2025
浙公网安备 33010602011771号 浙ICP备2021040463号-3