import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
class Model(nn.Module):
    """Small two-layer CNN: Conv2d(1->3, 3x3) + ReLU, then Conv2d(3->2, 3x3) + ReLU.

    Neither convolution uses padding, and both use the default stride of 1.
    Each layer therefore shrinks the spatial size by 2; for example, a 7x7
    input becomes 3x3.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept unchanged: renaming them would change the
        # state_dict keys ("cov1.weight", "cov2.bias", ...).
        self.cov1 = nn.Conv2d(1, 3, 3)
        self.cov2 = nn.Conv2d(3, 2, 3)

    # Must be named ``forward``: ``model(x)`` (nn.Module.__call__) dispatches
    # only to ``forward``. Under the original ``backword`` name, calling the
    # model raised NotImplementedError and the layers never ran.
    def forward(self, x):
        """Apply both conv + ReLU stages to ``x`` and return the result.

        Args:
            x: tensor with 1 input channel, e.g. shape (N, 1, H, W) with
                H, W >= 5 so that both unpadded 3x3 convolutions fit.

        Returns:
            Tensor with 2 channels and spatial size (H - 4, W - 4); values are
            non-negative because of the final ReLU.
        """
        print("model forward first", x.shape)
        x = F.relu(self.cov1(x))
        x = F.relu(self.cov2(x))
        print("model forward end", x.shape)
        return x

    # Backward-compatible alias for the original (misspelled) method name.
    backword = forward
def before_hook(model, input):
    """Forward pre-hook that logs the incoming input and replaces it with zeros.

    When registered with ``register_forward_pre_hook``, PyTorch calls it as
    ``hook(module, args)`` before ``forward`` runs. A non-tuple return value
    is wrapped into a 1-tuple and used as the new positional arguments.

    Args:
        model: the module whose ``forward`` is about to run.
        input: tuple of positional arguments passed to the module;
            ``input[0]`` must be a tensor.

    Returns:
        A zero tensor of shape (1, 1, 7, 7) that replaces the original input.
    """
    first = input[0]
    print("before hook", model, " input ", first.shape)
    # Match the incoming tensor's dtype/device so the replacement still works
    # after the model is moved with ``.to(device)`` / ``.double()``; a plain
    # ``torch.zeros(...)`` would always be default-dtype on the CPU.
    return torch.zeros(1, 1, 7, 7, dtype=first.dtype, device=first.device)
# Demo: register a forward pre-hook, run one forward pass, then unregister it.
model = Model()
hook = model.register_forward_pre_hook(before_hook)
# NOTE(review): ``input`` shadows the builtin at module level. The name is kept
# in case later code in this file refers to it; consider renaming.
input = torch.zeros(1, 1, 5, 5)
try:
    # before_hook replaces this 5x5 tensor with a 7x7 zero tensor before
    # the model's forward runs.
    model(input)
finally:
    # Always unregister the hook, even if the forward pass raises.
    hook.remove()