"""U2-Net prediction functions.

Covers single-image detection as well as video detection.
"""

import os
import time

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import time
import subprocess
from torchvision.transforms import transforms

from src import u2net_full
# Set the KMP_DUPLICATE_LIB_OK environment variable for this process.
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort seen when several libraries bundle their own OpenMP — confirm. It is
# set after the imports above; verify that is early enough to take effect.
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

def time_synchronized():
    """Return the current wall-clock time, synchronizing CUDA first if present.

    When ``torch.cuda.is_available()`` is true, ``torch.cuda.synchronize()``
    is called before the clock is read; otherwise the clock is read directly.

    Returns:
        float: The value of ``time.time()``.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

def get_gpu_info():
    """Query ``nvidia-smi`` for the name and memory figures of each GPU.

    Returns:
        list[list[str]] | None: One ``[name, memory_used, memory_total]``
        list per non-blank line printed by ``nvidia-smi`` (fields are the raw
        text split on ``', '``), or ``None`` when ``nvidia-smi`` cannot be
        started or exits with a non-zero status. An error message is printed
        in the ``None`` case.
    """
    command = [
        'nvidia-smi',
        '--query-gpu=name,memory.used,memory.total',
        '--format=csv,noheader',
    ]
    try:
        # Argument list instead of a shell string: no intermediate shell.
        cmd_out = subprocess.check_output(command)
    except (subprocess.CalledProcessError, OSError) as e:
        # Without a shell, a missing or non-executable nvidia-smi surfaces as
        # OSError (e.g. FileNotFoundError) rather than a non-zero exit code.
        print("Error while invoking nvidia-smi: ", e)
        return None

    # Drop blank lines so empty output yields [] rather than [['']], which
    # callers unpacking three fields per row could not handle.
    return [
        line.strip().split(', ')
        for line in cmd_out.decode().splitlines()
        if line.strip()
    ]

def print_gpu_usage():
    """Print the model name and memory usage of every GPU, one line each.

    The rows come from ``get_gpu_info()``; nothing is printed when it returns
    ``None`` or an empty list. The usage percentage on each line is computed
    from that GPU's own used/total values. Rows that do not have exactly
    three fields are skipped, and a percentage that cannot be computed
    (non-numeric field or zero total) is shown as ``N/A``.
    """
    gpu_rows = get_gpu_info()
    if not gpu_rows:
        return

    for row in gpu_rows:
        if len(row) != 3:
            # Not a "name, used, total" record; nothing meaningful to print.
            continue
        name, used, total = row
        try:
            # The first whitespace-separated token of each field is parsed as
            # the integer amount; any remaining text (a unit) is ignored.
            used_memory = int(used.strip().split()[0])
            total_memory = int(total.strip().split()[0])
            usage_text = f"{round(used_memory / total_memory * 100, 2)}%"
        except (ValueError, IndexError, ZeroDivisionError):
            usage_text = "N/A"
        print(f"GPU: {name.strip()}, Memory used: {used.strip()}, Memory total: {total.strip()}"
              f", Memory usage: {usage_text}")

def Image_Blend(src, res):
    """Blend a highlight of the segmented region over the original image.

    Every pixel of ``res`` that is not exactly ``[0, 0, 0]`` is treated as
    foreground and painted ``[0, 0, 255]`` in an overlay image (all other
    overlay pixels are black). The overlay is then mixed with ``src`` as
    ``0.8 * src + 0.2 * overlay``.

    Args:
        src: Original image; passed unchanged to ``cv2.addWeighted``, so it
            must be an array OpenCV accepts with the same height, width and
            channel count as the overlay.
        res: Segmented image of shape ``(H, W, 3)``. It is only read; the
            caller's array is not modified.

    Returns:
        The blended image produced by ``cv2.addWeighted`` with an 8-bit
        output type requested.
    """
    height, width = res.shape[:2]
    # Foreground = any pixel that is non-zero in at least one channel.
    mask = ~(res == [0, 0, 0]).all(axis=2)
    # Paint the highlight into a fresh uint8 image. This holds the same pixel
    # values as recolouring `res` in place, but leaves the caller's array
    # untouched and always hands OpenCV an 8-bit image, even when `res` has
    # an integer dtype OpenCV cannot process (such as int64).
    dst = np.zeros((height, width, 3), np.uint8)
    dst[mask] = [0, 0, 255]
    # Weighted mix of the original image and the highlight overlay.
    img = cv2.addWeighted(src, 0.8, dst, 0.2, 0, dtype=cv2.CV_8UC3)
    return img

def gpu_info() -> str:
    """Describe every CUDA device reported by torch, one device per line.

    Returns:
        str: Lines of the form ``CUDA:<index> (<name>, <memory>MiB)`` joined
        by newlines without a trailing newline, where ``<memory>`` is the
        device's ``total_memory`` divided by 2**20 and rounded to an integer.
        An empty string when ``torch.cuda.device_count()`` is 0.
    """
    descriptions = []
    for index in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(index)
        mebibytes = props.total_memory / (1 << 20)
        descriptions.append(f'CUDA:{index} ({props.name}, {mebibytes:.0f}MiB)')
    return '\n'.join(descriptions)

def pic_predict(threshold, device, data_transform, origin_img, model):
    """Segment a single image and save the blended visualisation to disk.

    Args:
        threshold: Cut-off applied to the resized model output; values
            strictly greater than it are treated as foreground.
        device: Torch device the input tensor is moved to before inference.
        data_transform: Callable applied to ``origin_img``; its result gets a
            leading batch dimension and is fed to ``model``.
        origin_img: Image array whose first two dimensions are height and
            width. It is treated as RGB: the blended result is converted
            RGB -> BGR before being written.
        model: Callable network. Its output is squeezed and resized to the
            image size — assumes it squeezes to a 2-D ``[H, W]`` map
            (TODO confirm against the model definition).

    Returns:
        None. The result is written to ``result/pred_result11.png``; a
        message is printed if the file cannot be written.
    """
    h, w = origin_img.shape[:2]
    img = data_transform(origin_img)
    img = torch.unsqueeze(img, 0).to(device)  # [C, H, W] -> [1, C, H, W]

    with torch.no_grad():
        # Inference: one forward pass; no separate warm-up pass is run since
        # a one-off prediction gains nothing from it.
        pred = model(img)

    pred = torch.squeeze(pred).to("cpu").numpy()  # [1, 1, H, W] -> [H, W]
    pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
    # uint8 0/1 mask keeps the masked image uint8 on every platform, which is
    # what the OpenCV-based blending step needs.
    pred_mask = (pred > threshold).astype(np.uint8)
    origin_img = np.array(origin_img, dtype=np.uint8)
    seg_img = origin_img * pred_mask[..., None]

    img_res = Image_Blend(origin_img, seg_img)

    out_path = "result/pred_result11.png"
    # cv2.imwrite only returns False when the target directory is missing,
    # so create it first and report a failed write explicitly.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    written = cv2.imwrite(out_path, cv2.cvtColor(img_res.astype(np.uint8), cv2.COLOR_RGB2BGR))
    if not written:
        print(f"Failed to write {out_path}")

def video_pre(threshold, device, data_transform, origin_img, model):
    """Segment one video frame and return the blended visualisation.

    Args:
        threshold: Cut-off applied to the resized model output; values
            strictly greater than it are treated as foreground.
        device: Torch device the input tensor is moved to before inference.
        data_transform: Callable applied to ``origin_img``; its result gets a
            leading batch dimension and is fed to ``model``.
        origin_img: Frame as an array whose first two dimensions are height
            and width.
        model: Callable network. Its output is squeezed and resized to the
            frame size — assumes it squeezes to a 2-D ``[H, W]`` map
            (TODO confirm against the model definition).

    Returns:
        The frame blended with the segmentation highlight, as returned by
        ``Image_Blend``.
    """
    h, w = origin_img.shape[:2]
    img = data_transform(origin_img)
    img = torch.unsqueeze(img, 0).to(device)  # [C, H, W] -> [1, C, H, W]

    with torch.no_grad():
        # Inference: exactly one forward pass per frame. A warm-up pass here
        # would run for every frame and double the per-frame cost.
        pred = model(img)

    # Print GPU usage information right after inference.
    print_gpu_usage()

    pred = torch.squeeze(pred).to("cpu").numpy()  # [1, 1, H, W] -> [H, W]
    pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
    # uint8 0/1 mask keeps the masked frame uint8 on every platform, which is
    # what the OpenCV-based blending step needs.
    pred_mask = (pred > threshold).astype(np.uint8)
    origin_img = np.array(origin_img, dtype=np.uint8)
    seg_img = origin_img * pred_mask[..., None]

    img_res = Image_Blend(origin_img, seg_img)
    return img_res

def _segment_video(video_path, threshold, device, data_transform, model):
    """Segment every frame of a video and write the blended frames to an AVI file."""
    cap = cv2.VideoCapture(video_path)
    # Read the frame size (as integers) and frame rate from the source.
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps or not fps > 0:
        # Frame rate unknown: fall back to 30.
        fps = 30

    # The writer cannot create missing directories itself.
    os.makedirs("result", exist_ok=True)
    out = cv2.VideoWriter('result/OutPut2.avi', cv2.VideoWriter_fourcc(*'FFV1'), fps,
                          (frame_width, frame_height))

    frame_index = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # End of stream or read failure: `frame` is not usable, so
                # stop before trying to process it.
                break
            frame_index += 1
            start_time = time.time()  # start of processing for this frame
            # OpenCV decodes frames as BGR; convert to RGB so the network
            # sees the same channel order as in the still-image path, then
            # convert the result back to BGR for the writer.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img_res = video_pre(threshold, device, data_transform, rgb_frame, model)
            out.write(cv2.cvtColor(img_res.astype(np.uint8), cv2.COLOR_RGB2BGR))
            cost_time = time.time() - start_time
            print("检测第 {} 帧花了 {:.8f}s 。".format(frame_index, cost_time))
    finally:
        # Release the capture and writer even if a frame fails, so the output
        # file is finalized.
        cap.release()
        out.release()


def _segment_image(image_path, threshold, device, data_transform, model):
    """Segment one image file via ``pic_predict`` and print the elapsed time."""
    start_time = time.time()  # start of processing for this image
    bgr_img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
    if bgr_img is None:
        # cv2.imread returns None instead of raising when decoding fails.
        raise ValueError(f"failed to read image file {image_path}")
    origin_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
    pic_predict(threshold, device, data_transform, origin_img, model)
    cost_time = time.time() - start_time
    print("检测一张图片花了 {:.8f}s 。".format(cost_time))


def main():
    """Load the U2-Net weights and run segmentation on the configured input.

    The input path, weights path and threshold are fixed inside the function.
    A ``.mp4`` input is processed frame by frame into ``result/OutPut2.avi``;
    a ``.jpg`` input is processed by ``pic_predict``; any other extension only
    prints a message. The extension check is case-insensitive.

    Raises:
        AssertionError: If the input path does not exist.
        ValueError: If a ``.jpg`` input cannot be decoded.
    """
    weights_path = "model_best.pth"
    img_path = "test/video.mp4"
    threshold = 0.5

    # Check that the input path exists.
    assert os.path.exists(img_path), f"image file {img_path} dose not exists."
    # Report the available GPUs, if any.
    if torch.cuda.is_available():
        print(gpu_info())
    # Use the first GPU when CUDA is available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(320),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])

    # Load the model.
    model = u2net_full()
    # NOTE(review): torch.load unpickles the checkpoint file; only load
    # weights from a trusted source.
    weights = torch.load(weights_path, map_location='cpu')
    # A checkpoint may wrap the state dict under a "model" key.
    state_dict = weights["model"] if "model" in weights else weights
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    suffix = os.path.splitext(img_path)[-1].lower()
    if suffix == ".mp4":
        print("视频读入")
        _segment_video(img_path, threshold, device, data_transform, model)
    elif suffix == ".jpg":
        print("图片读入")
        _segment_image(img_path, threshold, device, data_transform, model)
    else:
        print("请重新读入图片或者视频")

# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
# Source: blog post published 2023-04-28 13:17 by ~逍遥子~ (17 views, 0 comments).