def weight_init(m):
    """Initialize a module's parameters in place with int8-valued integers.

    Intended to be passed to ``torch.nn.Module.apply``. Handling per module
    type:

    * ``Conv3d`` / ``Conv2d``: the weight and, if present, the bias are filled
      with random integers drawn uniformly from the int8 range [-128, 127].
    * ``Linear``: the weight is filled the same way; the bias, if present, is
      zeroed.
    * ``BatchNorm3d``: the affine weight is set to 1 and the bias to 0. This
      is skipped for layers built with ``affine=False``, which have no such
      parameters.
    * Any other module is left untouched.

    The tensors keep their original dtype (``torch.randint_like`` defaults to
    the input's dtype); only their values become integers.

    Args:
        m: The module to initialize.
    """
    # torch.randint_like's ``high`` bound is exclusive, so 128 is required
    # for 127 to be reachable and the full int8 range [-128, 127] covered.
    int8_low, int8_high = -128, 128

    def _fill_int8(param):
        # Replace the parameter's values with random int8-range integers.
        param.data = torch.randint_like(param.data, low=int8_low, high=int8_high)

    if isinstance(m, (torch.nn.Conv3d, torch.nn.Conv2d)):
        _fill_int8(m.weight)
        # Use ``is not None``: ``!=`` on a tensor is an elementwise operation.
        if m.bias is not None:
            _fill_int8(m.bias)
    elif isinstance(m, torch.nn.BatchNorm3d):
        # With affine=False both parameters are None, so guard each one.
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, torch.nn.Linear):
        _fill_int8(m.weight)
        if m.bias is not None:
            m.bias.data.zero_()
# Initialize the model's weights with int8-range integer values.
# Module.apply calls weight_init on every submodule (children first) and then
# on `model` itself.
# NOTE(review): `model` is defined outside this view; presumably a
# torch.nn.Module — confirm.
model.apply(weight_init)