运行下面的代码，分析结果，不做解释。
import numpy as np
def func(img, label, flag_multi_class=True, num_class=3, *, verbose=True):
    """Normalize an image and convert its label array.

    Two modes are supported:

    * ``flag_multi_class=True`` (the default): ``img`` is scaled by 1/255 and
      ``label`` is reduced to a single channel, then one-hot encoded into
      ``num_class`` float channels appended as a new last axis. Label values
      outside ``[0, num_class)`` produce an all-zero vector.
    * ``flag_multi_class=False``: if ``img`` contains values above 1, both
      ``img`` and ``label`` are scaled by 1/255 and ``label`` is binarized at
      0.5. Otherwise both inputs are returned unchanged.

    Args:
        img: Image array (presumably 0-255 pixel values -- confirm with caller).
        label: Label array. In multi-class mode a 4-D label keeps index 0 of
            its last axis, a 2-D label is used as-is, and any other rank keeps
            index 0 of its third axis.
        flag_multi_class: Select one-hot (True) or binary (False) handling.
            Defaults to True, matching the previously hard-wired behavior.
        num_class: Number of one-hot channels produced in multi-class mode.
        verbose: In multi-class mode, print each per-class mask and the final
            arrays (the original debugging output).

    Returns:
        Tuple ``(img, label)`` with the processed arrays.
    """
    if flag_multi_class:
        img = img / 255.
        # Collapse to a single channel holding the class id.
        if label.ndim == 4:
            label = label[:, :, :, 0]
        elif label.ndim != 2:  # 2-D labels are already single-channel
            label = label[:, :, 0]
        new_label = np.zeros(label.shape + (num_class,))
        for i in range(num_class):
            class_mask = label == i
            # The boolean mask selects every pixel of class i; set channel i.
            new_label[class_mask, i] = 1
            if verbose:
                print('\n i = \n ', i, '\n label = \n', label,
                      '\n label==i: \n', class_mask)
        if verbose:
            print('\n\n after-img:\n', img, '\n\n after-label:\n', new_label, '\n')
        label = new_label
    elif np.max(img) > 1:
        img = img / 255.
        # Division yields a new array, so the in-place thresholding below
        # never mutates the caller's label.
        label = label / 255.
        label[label > 0.5] = 1
        label[label <= 0.5] = 0
    return (img, label)
# Demo image: a 3x3 integer array with pixel values in the 0-255 range.
img = np.array(
    [
        [129, 255, 30],
        [30, 30, 99],
        [90, 123, 49],
    ]
)
# Demo label: a 3x3x3 integer volume whose entries are class ids 0-2.
label = np.array(
    [
        [[1, 2, 0],
         [0, 1, 2],
         [2, 1, 0]],
        [[1, 2, 0],
         [2, 1, 2],
         [2, 1, 0]],
        [[0, 2, 0],
         [0, 1, 2],
         [2, 1, 0]],
    ]
)

# Show the raw inputs, then run the conversion (which prints its own results).
print('\n\n before-img:\n', img, '\n\n before-label:\n', label, '\n')
func(img, label)