nn.Linear 的局限性
nn.Linear(1, 1) 是 PyTorch 中最基础的线性层,表示输入维度为 1、输出维度也为 1,对应数学公式 y = wx + b。如果想实现更复杂的函数(如包含 x²、e^x、log(x) 等非线性项),需要手动组合线性层和激活函数。一、理解 nn.Linear 的局限性
nn.Linear(in_features, out_features) 只能表示线性变换,即:- 输入
x是标量时,输出y = wx + b(w和b是可学习参数) - 输入
x是向量时,输出y = Wx + b(W是权重矩阵,b是偏置向量)
无法直接表示非线性函数(如
x²、e^x、log(x))。二、实现复杂数学函数的两种方法
方法 1:自定义前向传播(推荐)
手动在
forward 方法中组合非线性项。例如,实现 y = w₁x² + w₂e^x + w₃log(x) + b:python
运行
import torch
import torch.nn as nn
class NonlinearModel(nn.Module):
def __init__(self):
super(NonlinearModel, self).__init__()
# 为每个非线性项设置独立的权重和一个共享偏置
self.weight1 = nn.Parameter(torch.randn(1)) # w₁
self.weight2 = nn.Parameter(torch.randn(1)) # w₂
self.weight3 = nn.Parameter(torch.randn(1)) # w₃
self.bias = nn.Parameter(torch.randn(1)) # b
def forward(self, x):
# 实现 y = w₁x² + w₂e^x + w₃log(x) + b
term1 = self.weight1 * (x ** 2) # x² 项
term2 = self.weight2 * torch.exp(x) # e^x 项
term3 = self.weight3 * torch.log(x) # log(x) 项
return term1 + term2 + term3 + self.bias
# 测试模型
model = NonlinearModel()
x = torch.tensor([2.0]) # 测试输入 x=2
print(f"输入 x={x.item()}")
print(f"输出 y={model(x).item()}")
方法 2:用线性层 + 激活函数近似(神经网络常用)
通过组合多个线性层和非线性激活函数(如
ReLU、Sigmoid),可以近似任意复杂函数(Universal Approximation Theorem)。例如:python
运行
import torch
import torch.nn as nn
class ApproximateModel(nn.Module):
def __init__(self):
super(ApproximateModel, self).__init__()
# 隐藏层:1个输入 → 10个神经元
self.hidden = nn.Linear(1, 10)
# 输出层:10个神经元 → 1个输出
self.output = nn.Linear(10, 1)
# 非线性激活函数(如ReLU)
self.relu = nn.ReLU()
def forward(self, x):
# 第一层:线性变换 + 非线性激活
hidden_output = self.relu(self.hidden(x))
# 第二层:线性变换(输出层通常不接激活函数)
return self.output(hidden_output)
# 测试模型
model = ApproximateModel()
x = torch.tensor([2.0])
print(f"输入 x={x.item()}")
print(f"输出 y={model(x).item()}")
三、两种方法对比
| 方法 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| 自定义前向传播 | 精确实现数学公式,参数明确 | 需要手动设计每个项 | 已知数学表达式的场景 |
| 神经网络近似 | 自动学习复杂函数关系 | 参数多,解释性差 | 函数形式未知,依赖数据学习 |
四、常见非线性函数的 PyTorch 实现
| 数学表达式 | PyTorch 代码 | 注意事项 |
|---|---|---|
| \(x^2\) | x ** 2 |
|
| \(e^x\) | torch.exp(x) |
x 需为正数 |
| \(\log(x)\) | torch.log(x) |
x 需为正数 |
| \(\sin(x)\) | torch.sin(x) |
|
| \(\text{ReLU}(x)\) | torch.relu(x) 或 nn.ReLU() |
最常用的激活函数:\(\max(0, x)\) |
五、练习:实现带根号和指数的函数
假设要实现 \(y = w_1\sqrt{x} + w_2e^{-x} + b\),可以这样写:
python
运行
class CustomFunction(nn.Module):
def __init__(self):
super(CustomFunction, self).__init__()
self.weight1 = nn.Parameter(torch.randn(1))
self.weight2 = nn.Parameter(torch.randn(1))
self.bias = nn.Parameter(torch.randn(1))
def forward(self, x):
# 确保输入为正数(避免根号和对数错误)
x = torch.clamp(x, min=1e-6) # 限制最小值为1e-6
term1 = self.weight1 * torch.sqrt(x) # √x 项
term2 = self.weight2 * torch.exp(-x) # e^(-x) 项
return term1 + term2 + self.bias
总结
- 线性层
nn.Linear只能表示线性变换,无法直接实现非线性函数。 - 自定义前向传播:通过手动组合数学运算(如
x**2、torch.exp(x))实现精确的非线性公式。 - 神经网络近似:通过堆叠线性层和激活函数,自动学习复杂函数关系(更适合数据驱动的场景)。

浙公网安备 33010602011771号