对3D图像进行裁剪

在对医学图像进行深度学习的过程中,我们会遇到图片过大,导致train的过程中网络会瘫痪,所以我们会考虑到对图像进行分割。比如一张155x240x240的图像,我们可以将他分割成一系列128x128x128大小的小图像。代码如下:

from crop import *
import torch
def split_image(img, crop_size=(128,128,128)):
    patient_image = img  # (1, 240, 240, 155, 4)
    patient_image = patient_image[0, ...]  # (240, 240, 155, 4)
    patient_image = patient_image.permute(3, 0, 1, 2)  # (1, 4, 155, 240, 240)
    patient_image = patient_image.cpu().numpy()
    pasient_image = crop_pad(patient_image, crop_size)
    patient_image = torch.from_numpy(pasient_image).permute(1, 0, 2, 3, 4)  # (C, S, T, Y, W)
    patient_image = patient_image.unsqueeze(0).numpy()
    # print("patient_image", patient_image.shape)
    return patient_image
if __name__ =="__main__":
    device = torch.device('cuda')
    image_size = 128
    image = torch.rand((1, 155, 240, 240, 4), device=device)
    x = split_image(image, (128,128,128))
    print("before Pinjie",image.shape)
    print("after Pinjie",x.shape)

结果显示:

我们可以看到,我们将图像分割成8个128x128x128大小的图像。

详细代码,请看我的github

github
GitHub.
posted @ 2021-07-16 22:44  九叶草  阅读(557)  评论(0编辑  收藏  举报