服务器之家:专注于服务器技术及软件下载分享
分类导航

PHP教程|ASP.NET教程|Java教程|ASP教程|编程技术|正则表达式|C/C++|IOS|C#|Swift|Android|VB|R语言|JavaScript|易语言|vb.net|

服务器之家 - 编程语言 - 编程技术 - 简单使用PyTorch搭建GAN模型

简单使用PyTorch搭建GAN模型

2021-08-25 22:53机器之心颂贤编译 编程技术

本文将带大家了解 GAN的工作原理 ,并介绍如何 通过PyTorch简单上手GAN 。

 简单使用PyTorch搭建GAN模型

以往人们普遍认为生成图像是不可能完成的任务,因为按照传统的机器学习思路,我们根本没有真值(ground truth)可以拿来检验生成的图像是否合格。

2014年,Goodfellow等人则提出生成 对抗网络(Generative Adversarial Network, GAN ,能够让我们完全依靠机器学习来生成极为逼真的图片。GAN的横空出世使得整个人工智能行业都为之震动,计算机视觉和图像生成领域发生了巨变。

本文将带大家了解 GAN的工作原理 ,并介绍如何 通过PyTorch简单上手GAN 。

GAN的原理

按照传统的方法,模型的预测结果可以直接与已有的真值进行比较。然而,我们却很难定义和衡量到底怎样才算作是“正确的”生成图像。

Goodfellow等人则提出了一个有趣的解决办法:我们可以先训练好一个分类工具,来自动区分生成图像和真实图像。这样一来,我们就可以用这个分类工具来训练一个生成网络,直到它能够输出完全以假乱真的图像,连分类工具自己都没有办法评判真假。

简单使用PyTorch搭建GAN模型

按照这一思路,我们便有了GAN:也就是一个 生成器(generator) 和一个 判别器(discriminator) 。生成器负责根据给定的数据集生成图像,判别器则负责区分图像是真是假。GAN的运作流程如上图所示。

损失函数

在GAN的运作流程中,我们可以发现一个明显的矛盾:同时优化生成器和判别器是很困难的。可以想象,这两个模型有着完全相反的目标:生成器想要尽可能伪造出真实的东西,而判别器则必须要识破生成器生成的图像。

为了说明这一点,我们设D(x)为判别器的输出,即x是真实图像的概率,并设G(z)为生成器的输出。判别器类似于一种二进制的分类器,所以其目标是使该函数的结果最大化:简单使用PyTorch搭建GAN模型这一函数本质上是非负的二元交叉熵损失函数。另一方面,生成器的目标是最小化判别器做出正确判断的机率,因此它的目标是使上述函数的结果最小化。

因此,最终的损失函数将会是两个分类器之间的极小极大博弈,表示如下:简单使用PyTorch搭建GAN模型理论上来说,博弈的最终结果将是让判别器判断成功的概率收敛到0.5。然而在实践中,极大极小博弈通常会导致网络不收敛,因此仔细调整模型训练的参数非常重要。

在训练GAN时,我们尤其要注意学习率等超参数,学习率比较小时能让GAN在输入噪音较多的情况下也能有较为统一的输出。

计算环境

本文将指导大家通过PyTorch搭建整个程序(包括torchvision)。同时,我们将会使用Matplotlib来让GAN的生成结果可视化。以下代码能够导入上述所有库:

  1. ""
  2. Import necessary libraries to create a generative adversarial network 
  3. The code is mainly developed using the PyTorch library 
  4. ""
  5. import time 
  6. import torch 
  7. import torch.nn as nn 
  8. import torch.optim as optim 
  9. from torch.utils.data import DataLoader 
  10. from torchvision import datasets 
  11. from torchvision.transforms import transforms 
  12. from model import discriminator, generator 
  13. import numpy as np 
  14. import matplotlib.pyplot as plt 

数据集

数据集对于训练GAN来说非常重要,尤其考虑到我们在GAN中处理的通常是非结构化数据(一般是图片、视频等),任意一class都可以有数据的分布。这种数据分布恰恰是GAN生成输出的基础。

为了更好地演示GAN的搭建流程,本文将带大家使用最简单的MNIST数据集,其中含有6万张手写阿拉伯数字的图片。

像 MNIST 这样高质量的非结构化数据集都可以在 格物钛 的 公开数据集 网站上找到。事实上,格物钛Open Datasets平台涵盖了很多优质的公开数据集,同时也可以实现 数据集托管及一站式搜索的功能 ,这对AI开发者来说,是相当实用的社区平台。

简单使用PyTorch搭建GAN模型

硬件需求

一般来说,虽然可以使用CPU来训练神经网络,但最佳选择其实是GPU,因为这样可以大幅提升训练速度。我们可以用下面的代码来测试自己的机器能否用GPU来训练:

  1. ""
  2. Determine if any GPUs are available 
  3. ""
  4. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'

实现

网络结构

由于数字是非常简单的信息,我们可以将判别器和生成器这两层结构都组建成全连接层(fully connected layers)。

我们可以用以下代码在PyTorch中搭建判别器和生成器:

  1. ""
  2. Network Architectures 
  3. The following are the discriminator and generator architectures 
  4. ""
  5.  
  6. class discriminator(nn.Module): 
  7.     def __init__(self): 
  8.         super(discriminator, self).__init__() 
  9.         self.fc1 = nn.Linear(784512
  10.         self.fc2 = nn.Linear(5121
  11.         self.activation = nn.LeakyReLU(0.1
  12.  
  13.     def forward(self, x): 
  14.         x = x.view(-1784
  15.         x = self.activation(self.fc1(x)) 
  16.         x = self.fc2(x) 
  17.         return nn.Sigmoid()(x) 
  18.  
  19.  
  20. class generator(nn.Module): 
  21.     def __init__(self): 
  22.         super(generator, self).__init__() 
  23.         self.fc1 = nn.Linear(1281024
  24.         self.fc2 = nn.Linear(10242048
  25.         self.fc3 = nn.Linear(2048784
  26.         self.activation = nn.ReLU() 
  27.  
  28. def forward(self, x): 
  29.     x = self.activation(self.fc1(x)) 
  30.     x = self.activation(self.fc2(x)) 
  31.     x = self.fc3(x) 
  32.     x = x.view(-112828
  33.     return nn.Tanh()(x) 

训练

在训练GAN的时候,我们需要一边优化判别器,一边改进生成器,因此每次迭代我们都需要同时优化两个互相矛盾的损失函数。

对于生成器,我们将输入一些随机噪音,让生成器来根据噪音的微小改变输出的图像:

  1. ""
  2. Network training procedure 
  3. Every step both the loss for disciminator and generator is updated 
  4. Discriminator aims to classify reals and fakes 
  5. Generator aims to generate images as realistic as possible 
  6. ""
  7. for epoch in range(epochs): 
  8.     for idx, (imgs, _) in enumerate(train_loader): 
  9.         idx += 1
  10.  
  11.         # Training the discriminator 
  12.         # Real inputs are actual images of the MNIST dataset 
  13.         # Fake inputs are from the generator 
  14.         # Real inputs should be classified as 1 and fake as 0
  15.         real_inputs = imgs.to(device) 
  16.         real_outputs = D(real_inputs) 
  17.         real_label = torch.ones(real_inputs.shape[0], 1).to(device) 
  18.  
  19.         noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
  20.         noise = noise.to(device) 
  21.         fake_inputs = G(noise) 
  22.         fake_outputs = D(fake_inputs) 
  23.         fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device) 
  24.  
  25.         outputs = torch.cat((real_outputs, fake_outputs), 0
  26.         targets = torch.cat((real_label, fake_label), 0
  27.  
  28.         D_loss = loss(outputs, targets) 
  29.         D_optimizer.zero_grad() 
  30.         D_loss.backward() 
  31.         D_optimizer.step() 
  32.  
  33.         # Training the generator 
  34.         # For generator, goal is to make the discriminator believe everything is 1
  35.         noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
  36.         noise = noise.to(device) 
  37.  
  38.         fake_inputs = G(noise) 
  39.         fake_outputs = D(fake_inputs) 
  40.         fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device) 
  41.         G_loss = loss(fake_outputs, fake_targets) 
  42.         G_optimizer.zero_grad() 
  43.         G_loss.backward() 
  44.         G_optimizer.step() 
  45.  
  46.         if idx % 100 == 0 or idx == len(train_loader): 
  47.             print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item())) 
  48.  
  49.     if (epoch+1) % 10 == 0
  50.         torch.save(G, 'Generator_epoch_{}.pth'.format(epoch)) 
  51.         print('Model saved.'

结果

经过100个训练时期之后,我们就可以对数据集进行可视化处理,直接看到模型从随机噪音生成的数字:

简单使用PyTorch搭建GAN模型

我们可以看到,生成的结果和真实的数据非常相像。考虑到我们在这里只是搭建了一个非常简单的模型,实际的应用效果会有非常大的上升空间。

不仅是有样学样

GAN和以往机器视觉专家提出的想法都不一样,而利用GAN进行的具体场景应用更是让许多人赞叹深度网络的无限潜力。下面我们来看一下两个最为出名的GAN延申应用。

CycleGAN

朱俊彦等人2017年发表的CycleGAN能够在没有配对图片的情况下将一张图片从X域直接转换到Y域,比如把马变成斑马、将热夏变成隆冬、把莫奈的画变成梵高的画等等。这些看似天方夜谭的转换CycleGAN都能轻松做到,并且结果非常准确。

简单使用PyTorch搭建GAN模型

GauGAN

英伟达则通过GAN让人们能够只需要寥寥数笔勾勒出自己的想法,便能得到一张极为逼真的真实场景图片。虽然这种应用需要的计算成本极为高昂,但是GauGAN凭借它的转换能力探索出了前所未有的研究和应用领域。

简单使用PyTorch搭建GAN模型

结语

相信看到这里,你已经知道了GAN的大致工作原理,并且能够自己动手简单搭建一个GAN了。

原文链接:https://www.jiqizhixin.com/articles/2021-08-11-4?utm_source=tuicool&utm_medium=referral

延伸 · 阅读

精彩推荐