生成对抗样本之 FGSM(Fast Gradient Sign Method)快速梯度符号方法,MNIST上,Pytorch实现
关于FGSM梯度下降用于生成对抗样本
1. 数学原理
在生成对抗样本领域,有一个简单的方法是FGSM(快速梯度符号),它的核心思想是:在模型对图像进行分类的时候,对输入的图片加一个绝对值为 的扰动,这个扰动的符号取决于模型的损失函数对输入图像X的梯度符号,如下:

- 先求模型的输出f(x)与标签y的损失对输入图像X的梯度
- 损失loss对输入X的梯度方向,是我们施加攻击的方向,这里取梯度的方向,也就是让loss变大,模型的效果变差
- 每一次取损失的符号sign(loss)乘上我们认为设置的扰动,附加到原图像X上,这个就作为我们的对抗样本
- 优点:实施攻击的时候,可以一次性对所有的图像进行处理,模型预测一次,计算一次损失,计算一次梯度,附加噪声,生成对抗样本,所以这也是一种实现最快的生成对抗样本手段
- 缺点:扰动的大小是认为设置的,小了,模型能够过滤该噪声,攻击失效;大了,肉眼可见,攻击的隐蔽性就不足了。
总结,FGSM就是利用损失函数对输入的梯度,计算一次梯度,加一次扰动在原图像上,让模型对当前样本的预测变得更“错”——往让它犯错的方向推一点点。
2. 演示伪代码
x.requires_grad = True # 让输入参与梯度计算
output = model(x) # 正常前向传播
loss = loss_fn(output, y) # 计算损失
loss.backward() # 反向传播,得到 ∇xL
x_adv = x + ε * sign(∇xL) # 在输入上添加扰动 → 攻击样本
3. Pytorch在MNIST上实施FGSM攻击代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# ----------------------------
# 1. 准备 MNIST 数据
# ----------------------------
transform = transforms.Compose([
transforms.ToTensor()
])
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('.', train=False, download=True, transform=transform),
batch_size=1, shuffle=True
)
# ----------------------------
# 2. 定义一个简单 CNN 模型
# ----------------------------
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv = nn.Conv2d(1, 10, kernel_size=5)
self.fc = nn.Linear(10 * 24 * 24, 10) # (28-5+1)^2 = 24x24
def forward(self, x):
x = F.relu(self.conv(x))
x = x.view(-1, 10 * 24 * 24)
return self.fc(x)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
# 训练模型(为了快速示例我们用少量 epoch)
def train(model, epochs=1):
optimizer = torch.optim.Adam(model.parameters())
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('.', train=True, download=True, transform=transform),
batch_size=64, shuffle=True
)
model.train()
for epoch in range(epochs):
for x, y in train_loader:
x, y = x.to(device), y.to(device)
output = model(x)
loss = F.cross_entropy(output, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train(model, epochs=1)
model.eval()
# ----------------------------
# 3. FGSM 攻击函数
# ----------------------------
def fgsm_attack(model, loss_fn, x, y, epsilon):
x_adv = x.clone().detach().to(device).requires_grad_(True)
y = y.to(device)
output = model(x_adv)
loss = loss_fn(output, y)
model.zero_grad()
loss.backward()
grad_sign = x_adv.grad.data.sign()
x_adv = x_adv + epsilon * grad_sign
x_adv = torch.clamp(x_adv, 0, 1)
return x_adv.detach()
# ----------------------------
# 4. 攻击 + 可视化
# ----------------------------
loss_fn = nn.CrossEntropyLoss()
epsilon = 0.25 # 你可以尝试 0.1 ~ 0.3
# 随机选一张图像
for x, y in test_loader:
x, y = x.to(device), y.to(device)
pred_orig = model(x).argmax(dim=1).item()
x_adv = fgsm_attack(model, loss_fn, x, y, epsilon)
pred_adv = model(x_adv).argmax(dim=1).item()
break
# ----------------------------
# 5. 显示图像对比
# ----------------------------
plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f"Original (pred: {pred_orig})")
plt.imshow(x.squeeze().cpu().numpy(), cmap='gray')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.title(f"Adversarial (pred: {pred_adv})")
plt.imshow(x_adv.squeeze().cpu().numpy(), cmap='gray')
plt.axis('off')
plt.tight_layout()
plt.show()

浙公网安备 33010602011771号