Pytorch 医疗图像分割里的Dice

图像分割里的Dice看了下也是蛮多的,其中最常用的吧就是:
两张图片的交乘以2除以他们的和
这个写的特别好医学图像分割之 Dice Loss,看这个就够了,下面我自己记录下
在做完交叉熵后,由于出来的是每个像素的类别预测概率,得把这些个概率转为相对应的像素才行。

# lesion_pre为网络跑出来结果,假如它的shape是[4, 2, 512, 512],这里做的是2D的图像,
# torch送网络的结构就是批次、纬度、高、宽。我们需要将第二个维度转换为0和1,看那个预
# 测概率大就取哪个,这样0和1就可以赋值了
lesion_mm = torch.max(lesion_pre, dim=1).indices

接下来就计算Dice,计算的时候需要转换到cpu和numpy格式,如果之前是在cuda上面跑的话

def Dice(inp, target, eps=1):
	# 抹平了,弄成一维的
    input_flatten = inp.flatten()
    target_flatten = target.flatten()
    # 计算交集中的数量
    overlap = np.sum(input_flatten * target_flatten)
    # 返回值,让值在0和1之间波动
    return np.clip(((2. * overlap) / (np.sum(target_flatten) + np.sum(input_flatten) + eps)), 1e-4, 0.9999)

就酱

posted @ 2019-11-07 09:29  赫凯  阅读(328)  评论(0)    收藏  举报