Deep Neural Networks: Denoising Handwritten Digits with a Deep Autoencoder
Code:
import torch
import platform
print("PyTorch version:{}".format(torch.__version__))
print("Python version:{}".format(platform.python_version()))
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
# ToTensor scales pixel values to [0, 1]; Normalize with mean=0.5, std=0.5 maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5], std=[0.5])])
dataset_train = datasets.MNIST(root="./data",
                               transform=transform,
                               train=True,
                               download=True)
dataset_test = datasets.MNIST(root="./data",
                              transform=transform,
                              train=False)
train_load = torch.utils.data.DataLoader(dataset=dataset_train,
                                         batch_size=64,
                                         shuffle=True)
test_load = torch.utils.data.DataLoader(dataset=dataset_test,
                                        batch_size=64,
                                        shuffle=True)
def train_show_images():
    images, label = next(iter(train_load))
    print(images.shape)
    images_example = torchvision.utils.make_grid(images)
    images_example = images_example.numpy().transpose(1, 2, 0)
    mean = [0.5]
    std = [0.5]
    # Undo the Normalize transform so pixel values are back in [0, 1].
    images_example = images_example * std + mean
    # plt.imshow(images_example)
    # plt.show()
    plt.imsave("train_raw.png", images_example)
    print("ok")
    # Corrupt the grid with Gaussian noise and clip back into the valid pixel range.
    noisy_images = images_example + 0.5 * np.random.randn(*images_example.shape)
    noisy_images = np.clip(noisy_images, 0., 1.)
    # plt.imshow(noisy_images)
    # plt.show()
    plt.imsave("train_noisy.png", noisy_images)
    print("ok")

train_show_images()
class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        # Encoder: 1x28x28 -> 64x14x14 -> 128x7x7
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        # Decoder: 128x7x7 -> 64x14x14 -> 1x28x28
        self.decoder = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2, mode="nearest"),
            torch.nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2, mode="nearest"),
            torch.nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1))

    def forward(self, input):
        output = self.encoder(input)
        output = self.decoder(output)
        return output
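# (Sanity check added in this write-up, not part of the original code.) The two 2x max-pools
# shrink the 28x28 input to 14x14 and then 7x7, and the two 2x upsamples restore it to 28x28,
# so a reconstruction has the same shape as its input:
with torch.no_grad():
    print(AutoEncoder()(torch.randn(1, 1, 28, 28)).shape)  # torch.Size([1, 1, 28, 28])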
model = AutoEncoder()
print(device)
model.to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters())
loss_f = torch.nn.MSELoss()
epoch_n = 1
for epoch in range(epoch_n):
    running_loss = 0.0
    print("Epoch {}/{}".format(epoch, epoch_n))
    print("-" * 10)
    for data in train_load:
        X_train, _ = data
        # Corrupt the clean batch with Gaussian noise; the model learns to undo it.
        noisy_X_train = X_train + 0.5 * torch.randn(X_train.shape)
        noisy_X_train = torch.clamp(noisy_X_train, 0., 1.)
        X_train, noisy_X_train = X_train.to(device), noisy_X_train.to(device)
        train_pre = model(noisy_X_train)
        loss = loss_f(train_pre, X_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the average per-batch MSE for the epoch.
    print("Loss is:{:.4f}".format(running_loss / len(train_load)))
def test_show_images():
    images, label = next(iter(test_load))
    print(images.shape)
    images_example = torchvision.utils.make_grid(images)
    images_example = images_example.numpy().transpose(1, 2, 0)
    mean = [0.5]
    std = [0.5]
    images_example = images_example * std + mean
    # plt.imshow(images_example)
    # plt.show()
    plt.imsave("test_raw.png", images_example)
    print("ok!")
    noisy_images = images_example + 0.5 * np.random.randn(*images_example.shape)
    noisy_images = np.clip(noisy_images, 0., 1.)
    # plt.imshow(noisy_images)
    # plt.show()
    plt.imsave("test_noisy.png", noisy_images)
    print("ok!!")
    # Denoise the same test batch with the trained model.
    noisy_X_test = images + 0.5 * torch.randn(images.shape)
    noisy_X_test = torch.clamp(noisy_X_test, 0., 1.)
    noisy_X_test = noisy_X_test.to(device)
    # no_grad avoids tracking gradients, so the result can be converted to NumPy directly.
    with torch.no_grad():
        test_pre = model(noisy_X_test)
    images_example = torchvision.utils.make_grid(test_pre).cpu()
    mean = 0.5
    std = 0.5
    images_example = images_example * std + mean
    images_example = torch.clamp(images_example, 0., 1.)
    images_example = images_example.numpy().transpose(1, 2, 0)
    # plt.imshow(images_example)
    # plt.show()
    plt.imsave("test_denoise.png", images_example)
    print("ok!!!")

test_show_images()
Test results after training:
Original handwritten images:

Handwritten images with added noise:

Denoised handwritten images:

PS:
The denoised images are not as sharp as the originals, but the overall denoising effect is still quite good.
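To back up that judgement with a number, one option (my own sketch, not from the original post) is to compare the noisy and denoised grids against the clean one using PSNR. It assumes float images in [0, 1]; the names clean_grid, noisy_grid and denoised_grid are hypothetical placeholders for the arrays produced by the script above.

import numpy as np

def psnr(clean, restored, data_range=1.0):
    # Peak signal-to-noise ratio in dB; higher means closer to the clean image.
    mse = np.mean((np.asarray(clean, dtype=np.float64) -
                   np.asarray(restored, dtype=np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10((data_range ** 2) / mse)

# Hypothetical usage with the grids from the script above:
# print("noisy    PSNR:", psnr(clean_grid, noisy_grid))
# print("denoised PSNR:", psnr(clean_grid, denoised_grid))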