Getting Started with PyTorch and Vision Competitions, Part 5: Building a Generative Adversarial Network with PyTorch

Principles of Generative Adversarial Networks

A generative adversarial network (GAN) produces remarkably good output through the adversarial interplay of (at least) two modules in its framework: a generative model and a discriminative model.

The Difference Between Generative and Discriminative Models

A discriminative model learns P(Y|X); for an unseen example X, the label Y follows directly from P(Y|X).

A generative model learns P(X, Y); for an unseen example X, it computes the joint probability of X with every possible label, and the label with the largest joint probability wins.

Components of a GAN

A GAN consists of two networks, a generator G and a discriminator D. G is a network that generates images: it takes a random noise vector z as input and produces an image from it, denoted G(z). D is a discriminator network that judges whether an image is "real". Its input is an image x, and its output D(x) is the probability that x is a real image: an output of 1 means the image is certainly real, and an output of 0 means it cannot possibly be real.

During training, the generator G tries to produce images realistic enough to fool the discriminator D, while D tries to tell the images generated by G apart from real ones. G and D thus play out a dynamic "game".

Generator

A randomly generated vector serves as the generator's input; ideally this random input follows a common distribution, such as a uniform or Gaussian distribution.

Discriminator

The discriminator takes an image as input and outputs a real/fake label for it, so it is effectively a binary classifier.

GAN Loss Function

\[ \min_{G}\max_{D}V(D,G)=E_{x\sim p_{data}(x)}[\log D(x)]+E_{z\sim p_z(z)}[\log (1-D(G(z)))] \]

The discriminator should classify the fake data produced by the generator as 0 (\(D(G(z))\rightarrow0\)) and classify real data as 1 (\(D(x)\rightarrow1\)).

The generator should fool the discriminator as far as possible, driving \(D(G(z))\rightarrow1\).

Since the discriminator is a binary classifier, binary cross entropy can measure how similar the distributions are. The loss of a single sample is: \[ H((x_1,y_1),D)=-y_1\log D(x_1)-(1-y_1)\log(1-D(x_1)) \] In a well-trained GAN, the generated distribution \(G(z)\) is very close to the real distribution; the discriminator can hardly tell generated data from real data and outputs probabilities close to 0.5 (equivalent to guessing whether the input is real or fake), so \(y_i=\frac{1}{2}\): \[ H((x_i,y_i)_{i=1}^{\infty},D)=-\frac{1}{2}E_{x\sim p_{data}(x)}[\log D(x)]-\frac{1}{2}E_{z\sim p_z(z)}[\log(1-D(G(z)))] \]
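To make the link with binary cross entropy concrete, here is a minimal sketch (illustrative values only; d_real and d_fake stand for hypothetical discriminator outputs) of how both losses can be computed with nn.BCELoss in PyTorch:

import torch
import torch.nn as nn

criterion = nn.BCELoss()

d_real = torch.full((8,), 0.9)  # hypothetical D(x) outputs for a batch of real samples
d_fake = torch.full((8,), 0.1)  # hypothetical D(G(z)) outputs for a batch of fake samples

# Discriminator loss: push D(x) -> 1 and D(G(z)) -> 0
loss_D = criterion(d_real, torch.ones(8)) + criterion(d_fake, torch.zeros(8))

# Generator loss: push D(G(z)) -> 1 to fool the discriminator
loss_G = criterion(d_fake, torch.ones(8))
print(loss_D.item(), loss_G.item())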

Training a GAN

  1. Sample randomly from the noise distribution, feed the samples to the generator, and obtain a batch of fake data, denoted \(G(z)\);
  2. Sample randomly from the real data distribution to obtain real data, denoted \(x\);
  3. Update the discriminator: its loss function has two parts, so feed the data from the two previous steps into the discriminator separately (the discriminator therefore sees two kinds of input, real and fake), run a forward and backward pass for each, and accumulate the resulting losses;
  4. Update the generator: keep D fixed and run a forward and backward pass;

Note in particular: in these steps the discriminator and the generator must be updated separately, yet both loss computations involve D and G, so the forward passes need special handling to avoid updating one model while the other is being updated (this is pointed out explicitly in the code analysis below).

Shortcomings of GANs

  1. Poor interpretability: the distribution \(P_g(G)\) learned by the generative model has no explicit expression.
  2. Hard to train: D and G need to be kept well synchronized (for example, updating D k times for every update of G). The GAN is defined as a minimax problem with no single loss function, so during training it is hard to tell whether progress is being made. Learning can suffer a collapse problem: the generator degenerates and keeps producing the same sample points, and learning cannot continue. When the generative model collapses, the discriminative model also pushes similar sample points in similar directions, and training stalls.
  3. Hard to converge: current theory holds that GANs should perform well at a Nash equilibrium, but gradient descent is only guaranteed to reach one for convex functions.
  4. Training a GAN means reaching a Nash equilibrium, which gradient descent sometimes achieves and sometimes does not. No reliable method for reaching the equilibrium has been found, so training GANs is unstable compared with VAEs or PixelRNN, though in practice it is still far more stable than training Boltzmann machines.
  5. GANs have difficulty learning to generate discrete data, such as text.
  6. Compared with Boltzmann machines, GANs find it hard to guess one pixel value from another; GANs are born to do one thing, namely generate all pixels at once. BiGAN can amend this trait, letting you use Gibbs sampling to impute missing values the way a Boltzmann machine does.

DCGAN Principles

DCGAN Network Structure

Both the discriminator and the generator of DCGAN use convolutional neural networks (CNNs) in place of the multilayer perceptrons of the original GAN. To keep the whole network differentiable, the pooling layers are removed from the CNNs; in addition, fully connected layers are replaced by global pooling to reduce computation.

The generator G expands a 100-dimensional noise vector into a 64 × 64 × 3 output; the whole process uses fractionally-strided convolutions.

For the different convolution variants, see: 反卷积Deconvolution and conv_arithmetic.
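As a quick illustration of how fractionally-strided convolutions grow a 100-dimensional noise vector spatially, here is a small sketch (layer sizes match the DCGAN generator below; the output-size formula is out = (in - 1) * stride - 2 * padding + kernel):

import torch
import torch.nn as nn

z = torch.randn(1, 100, 1, 1)  # 100-dim noise vector viewed as a 1x1 feature map

# kernel 4, stride 1, padding 0: 1x1 -> 4x4
up1 = nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False)
print(up1(z).shape)            # torch.Size([1, 512, 4, 4])

# kernel 4, stride 2, padding 1: doubles the spatial size, 4x4 -> 8x8
up2 = nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False)
print(up2(up1(z)).shape)       # torch.Size([1, 256, 8, 8])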

DCGAN Code Implementation

Basic Setup

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 2021
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

Data Loading and Preprocessing

Note that transforms.ToTensor takes input of shape [h, w, c] and outputs a tensor of shape [c, h, w].
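A quick check of this shape convention (an illustrative sketch with a dummy array):

import numpy as np
from PIL import Image
import torchvision.transforms as transforms

arr = np.zeros((96, 96, 3), dtype='uint8')       # [h, w, c]
t = transforms.ToTensor()(Image.fromarray(arr))
print(t.shape)                                   # torch.Size([3, 96, 96]), i.e. [c, h, w]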

import torch
from torch.utils.data import Dataset
from PIL import Image

class Mydata(Dataset):
    def __init__(self, transform=None):
        data1 = np.load(r'人脸关键点检测挑战赛_数据集\train.npy')
        data2 = np.load(r'人脸关键点检测挑战赛_数据集\test.npy')
        self.data = np.concatenate((data1, data2), 2)
        self.data = self.data.astype('uint8')
        self.transform = transform

    def __getitem__(self, idx):
        img = Image.fromarray(self.data[:, :, idx]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return self.data.shape[-1]

    def collate_fn(self, batch):
        #……
        pass

image_size = 64
batch_size = 128
workers = 2   # number of dataloader worker processes (defined here so the dataloader below runs)
ngpu = 1
nc = 3
transform_func = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5 for _ in range(nc)], [0.5 for _ in range(nc)])])
dataset = Mydata(transform=transform_func)
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
# Mydata returns bare image tensors, so the batch itself is the image tensor
plt.imshow(np.transpose(vutils.make_grid(real_batch.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
plt.show()

Model Design

# Root directory for dataset
dataroot = "data/celeba"
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
# size using a transformer.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Generator

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
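A quick sanity check of the generator's output shape (a throwaway sketch; netG_check is a temporary instance, assuming the hyperparameters defined above):

# a (batch, nz, 1, 1) latent batch should map to (batch, nc, 64, 64) images
netG_check = Generator(ngpu)
z = torch.randn(2, nz, 1, 1)
print(netG_check(z).shape)   # torch.Size([2, 3, 64, 64])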

Discriminator

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
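And the corresponding sanity check for the discriminator (again a throwaway sketch):

# a (batch, nc, 64, 64) image batch should map to one probability per image
netD_check = Discriminator(ngpu)
x = torch.randn(2, nc, 64, 64)
print(netD_check(x).view(-1).shape)   # torch.Size([2])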

Model and Optimizer Initialization

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netG.apply(weights_init)

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netD.apply(weights_init)

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

Training

Main steps:

  1. Zero the gradients;
  2. Forward pass;
  3. Compute the loss;
  4. Backward pass;
  5. Update the parameters.

Specifically:

  1. The generator first runs a forward pass to produce fake data;
  2. The discriminator is trained first:
    1. Zero the discriminator's gradients;
    2. Data sampled from the real distribution: forward pass, compute the loss \(-\frac{1}{2}E_{x\sim p_{data}(x)}[\log D(x)]\), backward pass;
    3. Data produced by the generator (the code here needs special care): forward pass, compute the loss \(-\frac{1}{2}E_{z\sim p_z(z)}[\log(1-D(G(z)))]\), backward pass;
    4. Update D's parameters;
  3. The generator is trained next:
    1. Zero the generator's gradients;
    2. Feed the generator's output into the freshly updated D, compute the loss, run the backward pass;
    3. Update G's parameters.
# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter(log_dir='log')
# Write the model graphs to TensorBoard.
# Note: the generator expects a (1, nz, 1, 1) latent input, not a full-size image.
init_input = torch.zeros((1, nz, 1, 1), device=device)
tb_writer.add_graph(netG, init_input)

init_input = torch.zeros((1, nc, image_size, image_size), device=device)
tb_writer.add_graph(netD, init_input)
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
# For each batch in the dataloader
for i, data in enumerate(dataloader, 0):

############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
## Train with all-real batch
# if epoch%10==0:
netD.zero_grad()
# Format batch
real_cpu = data.to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
# Forward pass real batch through D
output = netD(real_cpu).view(-1)
# Calculate loss on all-real batch
errD_real = criterion(output, label)
# Calculate gradients for D in backward pass
errD_real.backward()
D_x = output.mean().item()

## Train with all-fake batch
# Generate batch of latent vectors
noise = torch.randn(b_size, nz, 1, 1, device=device)
# Generate fake image batch with G
fake = netG(noise)
label.fill_(fake_label)
# Classify all fake batch with D
output = netD(fake.detach()).view(-1)
# Calculate D's loss on the all-fake batch
errD_fake = criterion(output, label)
# Calculate the gradients for this batch, accumulated (summed) with previous gradients
errD_fake.backward()
D_G_z1 = output.mean().item()
# Compute error of D as sum over the fake and the real batches
errD = errD_real + errD_fake
# Update D
optimizerD.step()

############################
# (2) Update G network: maximize log(D(G(z)))
###########################

netG.zero_grad()
label.fill_(real_label) # fake labels are real for generator cost
# Since we just updated D, perform another forward pass of all-fake batch through D
output = netD(fake).view(-1)
# Calculate G's loss based on this output
errG = criterion(output, label)
# Calculate gradients for G
errG.backward()
D_G_z2 = output.mean().item()
# Update G
optimizerG.step()

# Output training stats
if i % 50 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
% (epoch, num_epochs, i, len(dataloader),
errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

# Save Losses for plotting later
G_losses.append(errG.item())
D_losses.append(errD.item())
tags=['G_loss','D_loss','D_x', 'D_G_z1', 'D_G_z2']
for tag,value in zip(tags,[errG.item(),errD.item(),D_x, D_G_z1, D_G_z2]):
tb_writer.add_scalar('Train\%s'%tag, value, iters)
# Check how the generator is doing by saving G's output on fixed_noise
if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)): #可视化效果
with torch.no_grad():
fake = netG(fixed_noise).detach().cpu()
img=vutils.make_grid(fake, padding=2, normalize=True)
plt.imshow(img.permute((1, 2, 0)))
plt.show()
iters += 1
if epoch%100==0: #保存模型
num_params_D = sum(p.numel() for p in netD.parameters())
num_params_G = sum(p.numel() for p in netG.parameters())
checkpoint = {
'D_model_state_dict': netD.state_dict(), #*模型参数
'G_model_state_dict': netG.state_dict(), # *模型参数
'D_optimizer_state_dict': optimizerD.state_dict(), #*优化器参数
'G_optimizer_state_dict': optimizerG.state_dict(), # *优化器参数
'epoch': epoch,
'D_num_params': num_params_D,
'G_num_params': num_params_G,
}
torch.save(checkpoint, os.path.join(os.path.dirname(os.path.abspath(__file__))+'\\models', 'best_checkpoint.pt'))
print('end')

# path_checkpoint = "G:/ifwind/GAN/models/best_checkpoint.pt" # 断点路径
# checkpoint = torch.load(path_checkpoint) # 加载断点
# model = Generator(ngpu=1)
# model.load_state_dict(checkpoint['G_model_state_dict']) # 加载模型可学习参数
# model=model.to(device)
# with torch.no_grad():
# fake = model(fixed_noise)
#
# img=vutils.make_grid(fake.cpu(), padding=2, normalize=True)
plt.imshow(img_list[-1].permute((1, 2, 0)))
plt.show()

The earlier note about updating the discriminator and generator parameters separately shows up in the code as:

# Classify all fake batch with D
output = netD(fake.detach()).view(-1)
...
# Since we just updated D, perform another forward pass of all-fake batch through D
output = netD(fake).view(-1)

First, a word about the .detach() method: it cuts the gradient flow of backpropagation at a node, turning it into a tensor that requires no gradient, so when backpropagation reaches this node, gradients do not propagate any further back.
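A minimal standalone illustration of this behavior (not part of the training code):

import torch

a = torch.tensor([1.0], requires_grad=True)
b = (a * 2).detach()   # b is cut out of the graph; b.requires_grad is False
c = a * 2              # c stays in the graph
c.sum().backward()     # gradients flow back to a through c
print(a.grad)          # tensor([2.])
# calling b.sum().backward() would raise an error: b does not require grad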

Specifically, in the first output, the image fed to the discriminator is produced by the generator and truncated with .detach(). The goal is to cut off the backpropagated gradient flow so that this loss updates D without affecting G. It is equivalent to writing fake_AB = fake_B and then fake_AB.detach(), so that gradients cannot propagate back through fake_AB into netG. optimizerD then updates D's parameters.

In the second output, G needs to be updated, so the gradient flow is not truncated; this allows G's gradients to be computed, after which optimizerG updates G's parameters.

Several other GAN update strategies have been described elsewhere; see: pytorch训练GAN时的detach(). One of them is sketched below.
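One such alternative is to freeze the discriminator's parameters while updating G, instead of relying on which tensors are detached. A sketch of that pattern (set_requires_grad is a small helper defined here, not a PyTorch built-in; the pattern follows the pix2pix/CycleGAN codebase):

def set_requires_grad(net, flag):
    # enable/disable gradient computation for every parameter of net
    for p in net.parameters():
        p.requires_grad = flag

# While updating G: freeze D so its parameters accumulate no gradients
set_requires_grad(netD, False)
netG.zero_grad()
output = netD(fake).view(-1)
errG = criterion(output, label)
errG.backward()        # gradients still reach G through D's frozen layers
optimizerG.step()
set_requires_grad(netD, True)   # unfreeze D for its next update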

Training Output

[Figure: real data, the input noise, and the images generated from that noise]

Prediction

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch.to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

References

图解 生成对抗网络GAN 原理 超详解

机器学习“判定模型”和“生成模型”有什么区别? - politer的回答 - 知乎

DCGAN TUTORIAL

从头开始GAN【论文】(二) —— DCGAN

反卷积Deconvolution

https://blog.csdn.net/Hungryof/article/details/78035332

pytorch训练GAN时的detach()