训练生成对抗网络 (GAN)

我们探索了生成对抗网络的架构及其工作原理。在本章中,我们将通过一个实际示例来演示如何实现和训练 GAN 来生成手写数字,与 MNIST 数据集中的数字相同。我们将在此示例中使用 Python 以及 TensorFlow 和 Keras。

训练生成对抗网络的过程

GAN 的训练涉及迭代优化生成器模型和鉴别器模型。让我们通过以下步骤了解生成对抗网络 (GAN) 的训练过程:

初始化

  • 该过程从两个神经网络开始:生成器网络 (G) 和鉴别器网络 (D)。
  • 生成器以随机种子或噪声向量作为输入并生成生成的样本。
  • 鉴别器以真实数据样本或生成的样本作为输入,并将它们分类为真实或虚假。

生成虚假数据

  • 将随机噪声向量输入生成器网络。
  • 生成器处理此噪声并输出旨在类似于真实数据的生成数据样本。

生成器训练

  • 首先,它从输入的随机噪声中生成虚假数据。
  • 然后使用鉴别器的输出计算生成器的损失。
  • 最后,它更新生成器的权重以最小化损失。

鉴别器训练

  • 首先,它获取一批真实数据和一批假数据。
  • 然后,它计算真实数据和假数据的鉴别器损失。
  • 最后,它更新鉴别器的权重以最小化损失。

迭代训练

  • 重复步骤 2 到 4。在每次迭代中,生成器和鉴别器都会交替训练并尝试提高彼此的性能。
  • 这种交替优化持续到生成器生成与真实数据相同的数据并且鉴别器无法再可靠地区分真实数据和假数据数据。

训练和构建 GAN

在这里,我们将展示使用 Python 和 MNIST 数据集训练和构建 GAN 的分步过程 −

步骤 1:设置环境

在开始之前,我们需要使用必要的库设置 Python 环境。确保您的计算机上安装了 TensorFlow 和 Keras。您可以使用 pip 安装它们,如下所示 −

pip install tensorflow

步骤 2:导入必要的库

我们需要导入必要的库 −

import numpy as np
import tensorflow as tf
from tensorflow.keras import layer, models
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

步骤 3:加载和预处理 MNIST 数据集

MNIST 数据集包含 60,000 张训练图像和 10,000 张手写数字测试图像,每张图像大小为 28x28 像素。我们将像素值标准化为 [-1, 1] 范围,以使训练更有效率 −

# 加载数据集
(x_train, _), (_, _) = mnist.load_data()

# 将图像标准化为 [-1, 1]
x_train = (x_train - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=-1)

# 设置批量大小和缓冲区大小
BUFFER_SIZE = 60000
BATCH_SIZE = 256

步骤 4:创建生成器和鉴别器模型

生成器从随机噪声中创建假图像,鉴别器则尝试区分真假图像。

实现生成器模型

生成器模型将随机噪声向量作为输入,并通过一系列层对其进行转换以生成假图像 −

def build_generator():
   model = models.Sequential()
   model.add(layers.Dense(256, use_bias=False, input_shape=(100,)))
   model.add(layers.BatchNormalization())
   model.add(layers.LeakyReLU())
    
   model.add(layers.Dense(512, use_bias=False))
   model.add(layers.BatchNormalization())
   model.add(layers.LeakyReLU())
    
   model.add(layers.Dense(28 * 28 * 1, use_bias=False, activation='tanh'))
   model.add(layers.Reshape((28, 28, 1)))
    
   return model

generator = build_generator()

实现鉴别器模型

鉴别器模型将图像作为输入(真实或生成的)并输出一个概率值,指示图像是真实的还是假的 −

def build_discriminator():
   model = models.Sequential()
   model.add(layers.Flatten(input_shape=(28, 28, 1)))
   model.add(layers.Dense(512))
   model.add(layers.LeakyReLU())
   model.add(layers.Dropout(0.3))

   model.add(layers.Dense(256))
   model.add(layers.LeakyReLU())
   model.add(layers.Dropout(0.3))

   model.add(layers.Dense(1, activation='sigmoid'))

   return model

discriminator = build_discriminator()

第 5 步:定义损失函数和优化器

在此步骤中,我们将对生成器和鉴别器使用二元交叉熵损失。生成器旨在最大化鉴别器犯错的概率,而鉴别器旨在最小化其分类错误。

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(fake_output):
   return cross_entropy(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
   real_loss = cross_entropy(tf.ones_like(real_output), real_output)
   fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
   total_loss = real_loss + fake_loss
   return total_loss

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

第 6 步:定义训练循环

GAN 的训练过程涉及迭代训练生成器和鉴别器。在这里,我们将定义一个训练步骤,其中包括生成假图像、计算损失以及使用反向传播更新模型权重。

@tf.function
def train_step(images):
   noise = tf.random.normal([BATCH_SIZE, 100])
    
   with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)
        
      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)
        
      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)
    
   gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
   gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
   generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
   discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
   for epoch in range(epochs):
      for image_batch in dataset:
         train_step(image_batch)
      print(f'Epoch {epoch+1} completed')

第 7 步:准备数据集并训练 GAN

接下来,我们将通过对 MNIST 图像进行混洗和批处理来准备数据集,然后开始训练过程。

# 准备用于训练的数据集
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# 训练 GAN
EPOCHS = 50
train(train_dataset, EPOCHS)

第 8 步:生成和显示图像

现在,在训练 GAN 之后,我们可以生成并显示由生成器创建的新图像。它涉及创建随机噪声、将其提供给生成器并显示生成的图像。

def generate_and_save_images(model, epoch, test_input):
   predictions = model(test_input, training=False)
    
   fig = plt.figure(figsize=(7.50, 3.50))
    
   for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i + 1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')
    
   plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
   plt.show()

seed = tf.random.normal([16, 100])
generate_and_save_images(generator, EPOCHS, seed)

实施后,当您运行此代码时,您将获得以下输出 −

训练和构建 GAN

结论

使用 Python 训练 GAN 涉及几个关键步骤,例如设置环境、创建生成器和鉴别器模型、定义损失函数和优化器以及实施训练循环。通过遵循这些步骤,您可以训练自己的 GAN 并探索生成对抗网络的迷人世界。

在本章中,我们提供了使用 Python 编程语言构建和训练 GAN 的详细指南。我们使用 TensorFlow 和 Keras 库以及 MNIST 数据集作为示例。