nn.parameter
torch.nn.parameter.Parameter 是 PyTorch 中的一个类,它是 Tensor 的一个子类。特别之处在于,当你将其作为属性赋值给 Module 时,它会被自动添加到该模块的参数列表中,并可以通过如 parameters() 迭代器访问。这一特性使得管理模型参数变得更为方便和直观,尤其是对于深度学习模型而言。
主要特点
-
被视为模块参数的张量:
Parameter类允许你定义那些应当被视为模型参数的张量。这意味着它们会自动被包含在模型的参数集合中,这对于训练过程中的参数更新非常重要。 -
自动注册为模块参数:当一个
Parameter被赋值给一个Module类的属性时,它会自动成为该模块的一部分,参与后续的优化过程。
参数说明
-
data (Tensor):这是实际的参数张量,表示模型参数的具体数值。
-
requires_grad (bool, optional):这个布尔值指定了是否需要对该参数计算梯度。这对于确定哪些参数会在反向传播过程中得到更新至关重要。默认情况下,它的值是
True,意味着默认会对该参数进行梯度计算。需要注意的是,即使在torch.no_grad()上下文中,Parameter创建时的requires_grad默认值仍然为True,这表明了参数的梯度计算需求不会因为上下文的变化而自动改变。
通过使用 Parameter,PyTorch 提供了一种有效的方式来管理和优化深度学习模型中的参数,使得开发者可以更专注于模型设计和实验,而不必过多地考虑底层的细节。这种机制有助于简化代码逻辑,提高开发效率。

浙公网安备 33010602011771号