PyTorch 中返回 super().forward()
https://github.com/pytorch/pytorch/issues/42885
import torch
import torch.nn as nn
class Foo(nn.Conv1d):
    """Thin ``nn.Conv1d`` subclass whose forward defers to the parent class.

    ``super().forward(x)`` invokes ``nn.Conv1d.forward`` on this instance, so
    the result is exactly the ordinary 1-D convolution. On its own this
    override adds nothing; it is the place to add pre-/post-processing
    around the convolution.
    """

    def forward(self, x):
        convolved = super().forward(x)
        return convolved
这里的 `return super().forward(x)` 怎么理解？
调用父类（这里是 nn.Conv1d）的 forward() 方法，并返回它的计算结果，即直接沿用父类原有的前向计算。
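参考：https://stackoverflow.com/questions/54752983/calling-supers-forward-method（下面的例子同时说明：通过 module(x) 调用会经过 nn.Module.__call__，从而触发已注册的 forward hook；而直接调用 module.forward(x) 会跳过所有 hook。）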
import torch
class Parent(torch.nn.Module):
    """Toy module whose forward adds one to its input."""

    def forward(self, tensor):
        incremented = tensor + 1
        return incremented
class Child(Parent):
    """Adds one on top of ``Parent``'s output (input + 2 in total)."""

    def forward(self, tensor):
        # Zero-argument super() resolves to the same thing as super(Child, self).
        parent_out = super().forward(tensor)
        return parent_out + 1
module = Child()
# A forward hook may return a replacement output; this one adds 1,
# so calling the module should now yield `4`.
module.register_forward_hook(lambda mod, args, out: out + 1)
print(module(torch.tensor(1)))  # nn.Module.__call__ runs the hook -> tensor(4)
print(module.forward(torch.tensor(1)))  # calling forward() directly skips hooks -> tensor(3)
def increment_by_one(module, input, output):
    """Forward hook that replaces a module's output with ``output + 1``.

    Follows PyTorch's forward-hook signature ``hook(module, input, output)``;
    ``module`` and ``input`` are accepted but unused.
    """
    bumped = output + 1
    return bumped
class Parent(torch.nn.Module):
    """Minimal module: forward returns its input plus one."""

    def forward(self, tensor):
        result = tensor + 1
        return result
class Child(Parent):
    """``Parent`` plus one, with an output-incrementing forward hook on itself.

    The hook is registered exactly once, at construction time. Registering it
    inside ``forward`` would attach a *new* hook on every forward call, so
    hooks would pile up and ``module(...)`` outputs would keep growing across
    calls. Note that ``super().register_forward_hook`` and
    ``self.register_forward_hook`` are the same method: the hook always
    attaches to this module instance, not to some separate "parent" object.
    """

    def __init__(self):
        super().__init__()
        # Runs only when the module is invoked via __call__ (module(x)),
        # not when forward() is called directly.
        self.register_forward_hook(increment_by_one)

    def forward(self, tensor):
        # Parent adds 1; add one more here.
        return super().forward(tensor) + 1
module = Child()
# Attach one more output-incrementing hook from outside; together with the
# hook Child registers on itself, calling the module should give `5`.
module.register_forward_hook(increment_by_one)
print(module(torch.tensor(1)))  # via __call__, every hook runs -> tensor(5)
print(module.forward(torch.tensor(1)))  # direct forward() bypasses hooks -> tensor(3)
例如DenseNet中出现类似:
定义 DenseLayer（这里只用 add_module 注册了各网络层，forward 直接返回 super().forward(x)，也就是调用 nn.Sequential.forward，它会按添加顺序依次执行 norm → relu → conv → drop，所以不需要再手写前向逻辑）：
class DenseLayer(nn.Sequential):
    """One DenseNet layer: BatchNorm1d -> ReLU -> Conv1d(k=3) -> Dropout1d.

    Each sub-layer is registered with ``add_module``. ``nn.Sequential`` runs
    its children in registration order, so the forward needs no custom logic.

    Args:
        in_channels: number of input channels.
        growth_rate: number of output channels produced by the convolution.
        drop_rate: probability passed to ``nn.Dropout1d`` (default 0.2).
    """

    def __init__(self, in_channels, growth_rate, drop_rate=0.2):
        super().__init__()
        self.add_module('norm', nn.BatchNorm1d(in_channels))
        # In-place ReLU overwrites the BatchNorm output instead of allocating a new tensor.
        self.add_module('relu', nn.ReLU(inplace=True))
        # kernel_size=3 with stride=1 and padding=1 keeps the sequence length unchanged.
        self.add_module('conv', nn.Conv1d(in_channels, growth_rate, kernel_size=3,
                                          stride=1, padding=1, bias=False))
        self.add_module('drop', nn.Dropout1d(p=drop_rate))

    def forward(self, x):
        """Apply norm -> relu -> conv -> drop via ``nn.Sequential.forward``."""
        return super().forward(x)
通过DenseLayer组装DenseBlock:
class DenseBlock(nn.Module):
    """Stack of ``DenseLayer``s with dense (concatenating) connectivity.

    Layer ``k`` receives the block input concatenated with the outputs of all
    earlier layers, so its input width is ``in_channels + k * growth_rate``.
    The block output carries ``in_channels + n_layers * growth_rate`` channels.
    """

    def __init__(self, in_channels, growth_rate, n_layers):
        super().__init__()
        widths = (in_channels + k * growth_rate for k in range(n_layers))
        self.layers = nn.ModuleList(DenseLayer(w, growth_rate) for w in widths)

    def forward(self, x):
        features = x
        for dense_layer in self.layers:
            # Append this layer's new feature maps along dim 1 (the channel axis).
            features = torch.cat((features, dense_layer(features)), dim=1)
        return features
快去成为你想要的样子!
浙公网安备 33010602011771号