GAN训练中,如何用detach()实现的梯度截断?
摘要:我最近在学使用Pytorch写GAN代码,发现有些代码在训练部分细节有略微不同,其中有的人用到了detach()函数截断梯度流,有的人没用detch(),取而代之的是在损失函数在反向传播过程中将backward(retain_graph=T
我最近在学使用Pytorch写GAN代码,发现有些代码在训练部分细节有略微不同,其中有的人用到了detach()函数截断梯度流,有的人没用detch(),取而代之的是在损失函数在反向传播过程中将backward(retain_graph=True),本文通过两个 gan 的代码,介绍它们的作用,并分析,不同的更新策略对程序效率的影响。
这两个 GAN 的实现中,有两种不同的训练策略:
先训练判别器(discriminator),再训练生成器(generator),这是原始论文Generative Adversarial Networks中的算法
先训练generator,再训练discriminator
为了减少网络垃圾,GAN的原理网上一大堆,我这里就不重复赘述了,想要详细了解GAN原理的朋友,可以参考我专题文章:神经网络结构:生成式对抗网络(GAN)。
需要了解的知识:
detach():截断node反向传播的梯度流,将某个node变成不需要梯度的Varibale,因此当反向传播经过这个node时,梯度就不会从这个node往前面传播。
更新策略
我们直接下面进入本文正题,即,在 pytorch 中,detach 和 retain_graph 是干什么用的?本文将借助三段 GAN 的实现代码,来举例介绍它们的作用。
先训练判别器,再训练生成器
策略一
我们分析循环中一个 step 的代码:
valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真实标签,都是1
fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假标签,都是0
# ########################
# 训练判别器 #
# ########################
real_imgs = imgs.to(device) # 真实图片
z = torch.randn((imgs.shape[0], 100)).to(device) # 噪声
gen_imgs = generator(z) # 从噪声中生成假数据
pred_gen = discriminator(gen_imgs) # 判别器对假数据的输出
pred_real = discriminator(real_imgs) # 判别器对真数据的输出
optimizer_D.zero_grad() # 把判别器中所有参数的梯度归零
real_loss = adversarial_loss(pred_real, valid) # 判别器对真实样本的损失
fake_loss = adversarial_loss(pred_gen, fake) # 判别器对假样本的损失
d_loss = (real_loss + fake_loss) / 2 # 两项损失相加取平均
# 下面这行代码十分重要,将在正文着重讲解
d_loss.backward(retain_graph=True) # retain_graph=True 十分重要,否则计算图内存将会被释放
optimizer_D.step() # 判别器参数更新
# ########################
# 训练生成器 #
# ########################
g_loss = adversarial_loss(pred_gen, valid) # 生成器的损失函数
optimizer_G.zero_grad() # 生成器参数梯度归零
g_loss.backward() # 生成器的损失函数梯度反向传播
optimizer_G.step() # 生成器参数更新
代码讲解
鉴别器的损失函数d_loss是由real_loss和fake_loss组成的,而fake_loss又是noise经过generator来的。这样一来我们对d_loss进行反向传播,不仅会计算discriminator 的梯度还会计算generator 的梯度(虽然这一步optimizer_D.step()只更新 discriminator 的参数),因此下面在更新generator参数时,要先将generator参数的梯度清零,避免受到discriminator loss 回传过来的梯度影响。
