import torch
import torch.nn as nn


class Classifier(nn.Module):
    def __init__(self, in_size, in_ch):
        super(Classifier, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_ch, 3, 3, 1, 1),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(3, 6, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(6, 3, 3, 1, 1),
            nn.ReLU(),
        )
        # Both conv blocks keep the spatial size (kernel 3, stride 1, padding 1) and
        # output 3 channels, so the flattened feature size is 3 * in_size * in_size.
        self.fc = nn.Linear(3 * in_size * in_size, 1)

    def forward(self, x):
        x = self.layer1(x)
        identity = x
        x = self.layer2(x)
        # Out-of-place residual add: layer2 ends with ReLU, whose output is saved for
        # backward, so an in-place `x += identity` would break autograd here.
        x = x + identity
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
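

# Two kinds of hooks are demonstrated below:
#   - Tensor.register_hook(fn): fn(grad) is called during backward with the gradient
#     of the loss w.r.t. that tensor.
#   - Module.register_backward_hook(fn): fn(module, grad_input, grad_output) is called
#     with the gradients w.r.t. the module's inputs and outputs.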
def print_grad(grad):
    print('========= register_hook output: =========')
    print(grad.size())
    print(grad)


def grad_hook(md, grad_in, grad_out):
    print('========= register_backward_hook output: =========')
    # For the fc (Linear) layer, grad_in contains the gradients of the bias, the input
    # and the weight: (delta_bias, delta_x, delta_w).
    # grad_out is the gradient w.r.t. the module's output as a whole, which here also
    # equals grad_bias.
    print(grad_out[0].size())
    print(grad_out[0])
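

# Note: Module.register_backward_hook is deprecated in recent PyTorch releases in favour
# of register_full_backward_hook, and its results can be unreliable on container modules
# such as nn.Sequential; it is kept here to match the printed messages above.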

# Fix the seed so the random input and the randomly initialised weights are reproducible.
torch.random.manual_seed(1000)


if __name__ == '__main__':
    in_size, in_ch = 4, 1
    x = torch.randn(1, 1, 4, 4)
    model = Classifier(in_size, in_ch)
    y_hat = model(x)
    y_gt = torch.tensor([[1.5]])
    crt = nn.MSELoss()
    print(y_hat)
    print('=======================')

    # Re-run the forward pass module by module so that a tensor hook can be attached to
    # every intermediate output and a backward hook to every sub-module.
    identity = []
    for idx, (name, md) in enumerate(model.named_children()):
        md.register_backward_hook(grad_hook)
        if isinstance(md, nn.Linear):
            x = x + identity[0]  # out-of-place residual add, as in forward()
            x = torch.flatten(x, 1)
        x = md(x)
        x.register_hook(print_grad)
        if idx == 0:
            identity.append(x)

    loss = crt(x, y_gt)
    loss.backward()
    print(x)
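
    # The hooks fire during backward in reverse order of the forward pass: fc first,
    # then layer2, then layer1. As a sanity check on the comment in grad_hook: with a
    # batch of size 1 the bias gradient of fc should match the grad_out value printed
    # for the fc module, since d(loss)/d(bias) = d(loss)/d(fc output) per sample.
    print(model.fc.bias.grad)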