from __future__ import print_function

# %matplotlib inline
import argparse
import os
import random

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from IPython.display import HTML
# Set random seed for reproducibility.
manualSeed = 2021
# manualSeed = random.randint(1, 10000)  # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)        # seed Python's built-in RNG
torch.manual_seed(manualSeed)  # seed PyTorch's default generator
# Create the dataloader.
# NOTE(review): `dataset` is not defined anywhere in the visible source, and
# `batch_size` / `workers` are only assigned in the hyper-parameter section
# further down -- confirm all three exist before this statement runs.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)
# Decide which device we want to run on: the first CUDA device when CUDA is
# available and at least one GPU was requested via `ngpu`, otherwise the CPU.
# NOTE(review): `ngpu` is assigned in the hyper-parameter section further
# down -- confirm it is defined before this statement runs.
device = torch.device(
    "cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu"
)
# Plot some training images.
real_batch = next(iter(dataloader))  # pull a single batch from the loader
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# Tile (at most) the first 64 entries of the batch into one grid image, bring
# it back to the CPU and reorder its axes with (1, 2, 0) before display.
training_grid = vutils.make_grid(
    real_batch[0].to(device)[:64], padding=2, normalize=True
).cpu()
plt.imshow(np.transpose(training_grid, (1, 2, 0)))
plt.show()
# Model design (模型设计)
# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
# size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialise the parameters of a single module in place.

    Intended to be passed to ``Module.apply`` so it is called once per
    sub-module. The module type is detected from its class name:

    * name contains ``'Conv'``: weight drawn from N(mean=0.0, std=0.02)
    * name contains ``'BatchNorm'``: weight drawn from N(mean=1.0, std=0.02)
      and bias set to 0

    Modules matching neither pattern are left untouched. Matching modules
    that have no ``weight`` / ``bias`` tensor (e.g. a container whose name
    happens to contain "Conv", or a BatchNorm built with ``affine=False``)
    are skipped instead of raising.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0)
# Generator
# Generator Code
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector to an image.

    The network is a stack of five transposed convolutions. The sizes are
    taken from the module-level settings ``nz`` (latent size), ``ngf``
    (base feature-map width) and ``nc`` (output channels). Following the
    per-layer size comments below, an input of shape (N, nz, 1, 1) yields
    an output of shape (N, nc, 64, 64); the final ``Tanh`` bounds the
    output to [-1, 1].

    Args:
        ngpu: number of GPUs; only stored on the instance as ``self.ngpu``.
    """

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Run ``input`` through the sequential stack and return the result."""
        return self.main(input)
# Discriminator
class Discriminator(nn.Module):
    """DCGAN discriminator: scores an image as real or fake.

    The network is a stack of five strided convolutions. The sizes are
    taken from the module-level settings ``nc`` (input channels) and
    ``ndf`` (base feature-map width). Following the per-layer size comments
    below, an input of shape (N, nc, 64, 64) is reduced to one value per
    sample; the final ``Sigmoid`` bounds that value to [0, 1].

    Args:
        ngpu: number of GPUs; only stored on the instance as ``self.ngpu``.
    """

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Run ``input`` through the sequential stack and return the result."""
        return self.main(input)
# Model and optimizer initialization (模型和优化器初始化)
# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to every sub-module: conv weights are drawn
# with mean=0, stdev=0.02; BatchNorm weights with mean=1, stdev=0.02 and
# BatchNorm biases are set to 0.
netG.apply(weights_init)
# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to every sub-module: conv weights are drawn
# with mean=0, stdev=0.02 (not 0.2); BatchNorm weights with mean=1,
# stdev=0.02 and BatchNorm biases are set to 0.
netD.apply(weights_init)
# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D (same learning rate and betas)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
# Bookkeeping for the run; (re)initialised here so the loop is self-contained.
img_list = []   # image grids generated from `fixed_noise` during training
G_losses = []   # generator loss, one entry per iteration
D_losses = []   # discriminator loss, one entry per iteration
iters = 0       # global iteration counter across all epochs

# NOTE(review): `tb_writer` is not created anywhere in the visible source; it
# is presumably a TensorBoard SummaryWriter -- confirm it exists before
# training starts.

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch. The plotting code indexes batches as `real_batch[0]`,
        # so a batch may be an (images, ...) sequence; a bare tensor batch is
        # accepted as well.
        real_images = data[0] if isinstance(data, (list, tuple)) else data
        real_cpu = real_images.to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label,
                           dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D; `detach()` keeps this backward pass
        # from propagating into G
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with
        # previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake
        # batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Read the scalar losses once and reuse them below
        loss_g = errG.item()
        loss_d = errD.item()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     loss_d, loss_g, D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(loss_g)
        D_losses.append(loss_d)

        # Log to TensorBoard. '/' is TensorBoard's tag separator; the
        # previous 'Train\%s' used an invalid escape sequence and produced
        # backslash-separated tags.
        tags = ['G_loss', 'D_loss', 'D_x', 'D_G_z1', 'D_G_z2']
        values = [loss_g, loss_d, D_x, D_G_z1, D_G_z2]
        for tag, value in zip(tags, values):
            tb_writer.add_scalar('Train/%s' % tag, value, iters)

        # Check how the generator is doing by saving G's output on fixed_noise
        is_last_batch = (epoch == num_epochs - 1) and (i == len(dataloader) - 1)
        if (iters % 500 == 0) or is_last_batch:
            # Visualize the current result
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img = vutils.make_grid(fake, padding=2, normalize=True)
            img_list.append(img)  # kept for the final real-vs-fake plot
            plt.imshow(img.permute((1, 2, 0)))
            plt.show()

        iters += 1

    # Save the model every 100 epochs and also after the final epoch, so a
    # short run (num_epochs < 100) does not keep only the epoch-0 weights.
    if epoch % 100 == 0 or epoch == num_epochs - 1:
        num_params_D = sum(p.numel() for p in netD.parameters())
        num_params_G = sum(p.numel() for p in netG.parameters())
        checkpoint = {
            'D_model_state_dict': netD.state_dict(),  # model parameters
            'G_model_state_dict': netG.state_dict(),  # model parameters
            'D_optimizer_state_dict': optimizerD.state_dict(),  # optimizer state
            'G_optimizer_state_dict': optimizerG.state_dict(),  # optimizer state
            'epoch': epoch,
            'D_num_params': num_params_D,
            'G_num_params': num_params_G,
        }
        # Store checkpoints in a "models" directory next to this script;
        # fall back to the working directory when `__file__` is undefined
        # (e.g. inside a notebook).
        try:
            script_dir = os.path.dirname(os.path.abspath(__file__))
        except NameError:
            script_dir = os.getcwd()
        model_dir = os.path.join(script_dir, 'models')
        if not os.path.isdir(model_dir):
            os.makedirs(model_dir)
        torch.save(checkpoint, os.path.join(model_dir, 'best_checkpoint.pt'))

print('end')
# Excerpt of two statements from the training loop, repeated for emphasis
# (illustrative only, intentionally left commented out):
#
#     # Classify all fake batch with D
#     output = netD(fake.detach()).view(-1)
#     ...
#     # Since we just updated D, perform another forward pass of all-fake
#     # batch through D
#     output = netD(fake).view(-1)
#
# The discriminator step feeds netD with `fake.detach()`, whereas the
# generator step feeds it `fake` itself, without detaching.
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15, 15))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
real_grid = vutils.make_grid(
    real_batch[0].to(device)[:64], padding=5, normalize=True
).cpu()
plt.imshow(np.transpose(real_grid, (1, 2, 0)))

# Plot the fake images from the last epoch. Use the most recent grid stored
# in `img_list` when there is one; otherwise (list missing or empty) generate
# a grid from the generator on `fixed_noise`.
try:
    last_fake_grid = img_list[-1]
except (NameError, IndexError):
    with torch.no_grad():
        last_fake_grid = vutils.make_grid(
            netG(fixed_noise).detach().cpu(), padding=2, normalize=True
        )
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(last_fake_grid, (1, 2, 0)))
plt.show()