新闻动态

Pytorch学习笔记DCGAN极简入门教程

发布日期:2022-01-21 16:02 | 文章来源:脚本之家

1.图片分类网络

这是一个二分类网络,可以是alxnet ,vgg,resnet任何一个,负责对图片进行二分类,区分图片是真实图片还是生成的图片

2.图片生成网络

输入是一个随机噪声,输出是一张图片,使用的是反卷积层

相信学过深度学习的都能写出这两个网络,当然如果你写不出来,没关系,有人替你写好了

首先是图片分类网络:

简单来说就是cnn+relu+sogmid,可以换成任何一个分类网络,比如bgg,resnet等

class Discriminator(nn.Module):
 def __init__(self, ngpu):
  super(Discriminator, self).__init__()
  self.ngpu = ngpu
  self.main = nn.Sequential(
# input is (nc) x 64 x 64
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
  )
 def forward(self, input):
  return self.main(input)

重点是生成网络

代码如下,其实就是反卷积+bn+relu

class Generator(nn.Module):
 def __init__(self, ngpu):
  super(Generator, self).__init__()
  self.ngpu = ngpu
  self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# state size. (ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# state size. (ngf*4) x 8 x 8
nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# state size. (ngf*2) x 16 x 16
nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# state size. (ngf) x 32 x 32
nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
# state size. (nc) x 64 x 64
  )
 def forward(self, input):
  return self.main(input)

讲道理,以上两个网络都挺简单。

真正的重点到了,怎么训练

每一个step分为三个步骤:

  • 训练二分类网络
    1.输入真实图片,经过二分类,希望判定为真实图片,更新二分类网络
    2.输入噪声,进过生成网络,生成一张图片,输入二分类网络,希望判定为虚假图片,更新二分类网络
  • 训练生成网络
    3.输入噪声,进过生成网络,生成一张图片,输入二分类网络,希望判定为真实图片,更新生成网络

不多说直接上代码

for epoch in range(num_epochs):
 # For each batch in the dataloader
 for i, data in enumerate(dataloader, 0):
  ############################
  # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
  ###########################
  ## Train with all-real batch
  netD.zero_grad()
  # Format batch
  real_cpu = data[0].to(device)
  b_size = real_cpu.size(0)
  label = torch.full((b_size,), real_label, device=device)
  # Forward pass real batch through D
  output = netD(real_cpu).view(-1)
  # Calculate loss on all-real batch
  errD_real = criterion(output, label)
  # Calculate gradients for D in backward pass
  errD_real.backward()
  D_x = output.mean().item()
  ## Train with all-fake batch
  # Generate batch of latent vectors
  noise = torch.randn(b_size, nz, 1, 1, device=device)
  # Generate fake image batch with G
  fake = netG(noise)
  label.fill_(fake_label)
  # Classify all fake batch with D
  output = netD(fake.detach()).view(-1)
  # Calculate D's loss on the all-fake batch
  errD_fake = criterion(output, label)
  # Calculate the gradients for this batch
  errD_fake.backward()
  D_G_z1 = output.mean().item()
  # Add the gradients from the all-real and all-fake batches
  errD = errD_real + errD_fake
  # Update D
  optimizerD.step()
  ############################
  # (2) Update G network: maximize log(D(G(z)))
  ###########################
  netG.zero_grad()
  label.fill_(real_label)  # fake labels are real for generator cost
  # Since we just updated D, perform another forward pass of all-fake batch through D
  output = netD(fake).view(-1)
  # Calculate G's loss based on this output
  errG = criterion(output, label)
  # Calculate gradients for G
  errG.backward()
  D_G_z2 = output.mean().item()
  # Update G
  optimizerG.step()
  # Output training stats
  if i % 50 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
% (epoch, num_epochs, i, len(dataloader),
errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
  # Save Losses for plotting later
  G_losses.append(errG.item())
  D_losses.append(errD.item())
  # Check how the generator is doing by saving G's output on fixed_noise
  if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
with torch.no_grad():
 fake = netG(fixed_noise).detach().cpu()
img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
  iters += 1

以上就是Pytorch学习笔记DCGAN极简入门教程的详细内容,更多关于Pytorch学习DCGAN入门教程的资料请关注本站其它相关文章!

版权声明:本站文章来源标注为YINGSOO的内容版权均为本站所有,欢迎引用、转载,请保持原文完整并注明来源及原文链接。禁止复制或仿造本网站,禁止在非www.yingsoo.com所属的服务器上建立镜像,否则将依法追究法律责任。本站部分内容来源于网友推荐、互联网收集整理而来,仅供学习参考,不代表本站立场,如有内容涉嫌侵权,请联系alex-e#qq.com处理。

相关文章

实时开通

自选配置、实时开通

免备案

全球线路精选!

全天候客户服务

7x24全年不间断在线

专属顾问服务

1对1客户咨询顾问

在线
客服

在线客服:7*24小时在线

客服
热线

400-630-3752
7*24小时客服服务热线

关注
微信

关注官方微信
顶部