GAN训练中,如何用detach()实现的梯度截断?

摘要:我最近在学使用Pytorch写GAN代码,发现有些代码在训练部分细节有略微不同,其中有的人用到了detach()函数截断梯度流,有的人没用detch(),取而代之的是在损失函数在反向传播过程中将backward(retain_graph=T
我最近在学使用Pytorch写GAN代码,发现有些代码在训练部分细节有略微不同,其中有的人用到了detach()函数截断梯度流,有的人没用detch(),取而代之的是在损失函数在反向传播过程中将backward(retain_graph=True),本文通过两个 gan 的代码,介绍它们的作用,并分析,不同的更新策略对程序效率的影响。   这两个 GAN 的实现中,有两种不同的训练策略: 先训练判别器(discriminator),再训练生成器(generator),这是原始论文Generative Adversarial Networks中的算法 先训练generator,再训练discriminator   为了减少网络垃圾,GAN的原理网上一大堆,我这里就不重复赘述了,想要详细了解GAN原理的朋友,可以参考我专题文章:神经网络结构:生成式对抗网络(GAN)。 需要了解的知识:   detach():截断node反向传播的梯度流,将某个node变成不需要梯度的Varibale,因此当反向传播经过这个node时,梯度就不会从这个node往前面传播。 更新策略   我们直接下面进入本文正题,即,在 pytorch 中,detach 和 retain_graph 是干什么用的?本文将借助三段 GAN 的实现代码,来举例介绍它们的作用。 先训练判别器,再训练生成器 策略一 我们分析循环中一个 step 的代码: valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真实标签,都是1 fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假标签,都是0 # ######################## # 训练判别器 # # ######################## real_imgs = imgs.to(device) # 真实图片 z = torch.randn((imgs.shape[0], 100)).to(device) # 噪声 gen_imgs = generator(z) # 从噪声中生成假数据 pred_gen = discriminator(gen_imgs) # 判别器对假数据的输出 pred_real = discriminator(real_imgs) # 判别器对真数据的输出 optimizer_D.zero_grad() # 把判别器中所有参数的梯度归零 real_loss = adversarial_loss(pred_real, valid) # 判别器对真实样本的损失 fake_loss = adversarial_loss(pred_gen, fake) # 判别器对假样本的损失 d_loss = (real_loss + fake_loss) / 2 # 两项损失相加取平均 # 下面这行代码十分重要,将在正文着重讲解 d_loss.backward(retain_graph=True) # retain_graph=True 十分重要,否则计算图内存将会被释放 optimizer_D.step() # 判别器参数更新 # ######################## # 训练生成器 # # ######################## g_loss = adversarial_loss(pred_gen, valid) # 生成器的损失函数 optimizer_G.zero_grad() # 生成器参数梯度归零 g_loss.backward() # 生成器的损失函数梯度反向传播 optimizer_G.step() # 生成器参数更新 代码讲解   鉴别器的损失函数d_loss是由real_loss和fake_loss组成的,而fake_loss又是noise经过generator来的。这样一来我们对d_loss进行反向传播,不仅会计算discriminator 的梯度还会计算generator 的梯度(虽然这一步optimizer_D.step()只更新 discriminator 的参数),因此下面在更新generator参数时,要先将generator参数的梯度清零,避免受到discriminator loss 回传过来的梯度影响。
阅读全文