Gen-ai 简明教程
Training a Generative Adversarial Network (GANs)
我们探索了生成对抗网络的结构及其工作原理。在本章中,我们将选取一个实际示例来说明如何实现和训练 GAN 生成手写数字,这些数字与 MNIST 数据集中的一样。我们将在此示例中使用 Python 以及 TensorFlow 和 Keras。
Process of Training a Generative Adversarial Network
GAN 的培训涉及迭代优化生成器模型和判别器模型。让我们通过以下步骤了解生成对抗网络 (GAN) 的训练流程:
Training and Building a GAN
在这里,我们将展示使用 Python 和 MNIST 数据集训练和构建 GAN 的逐步过程 -
Step 1: Setting Up the Environment
在开始之前,我们需要使用必要的库来设置 Python 环境。确保您的计算机上已安装 TensorFlow 和 Keras。您可以使用 pip 如下安装它们 -
pip install tensorflow
Step 2: Import Necessary Libraries
我们需要导入必要的库 -
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
Step 3: Load and Preprocess the MNIST Dataset
MNIST 数据集包含 60,000 张手写数字训练图像和 10,000 张测试图像,每张图像大小为 28x28 像素。我们将像素值归一化到范围 [-1, 1] 以提高训练效率 -
# Load the dataset
(x_train, _), (_, _) = mnist.load_data()
# Normalize the images to [-1, 1]
x_train = (x_train - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=-1)
# Set batch size and buffer size
BUFFER_SIZE = 60000
BATCH_SIZE = 256
Step 4: Create the Generator and Discriminator Models
生成器从随机噪声创建伪造图像,判别器尝试区分真实和伪造图像。
生成器模型将随机噪声向量作为输入,并通过一系列层对其进行转换,生成伪造图像 -
def build_generator():
model = models.Sequential()
model.add(layers.Dense(256, use_bias=False, input_shape=(100,)))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Dense(512, use_bias=False))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Dense(28 * 28 * 1, use_bias=False, activation='tanh'))
model.add(layers.Reshape((28, 28, 1)))
return model
generator = build_generator()
判别器模型将图像作为输入(真实或生成的),并输出一个概率值,表示图像真实还是伪造 -
def build_discriminator():
model = models.Sequential()
model.add(layers.Flatten(input_shape=(28, 28, 1)))
model.add(layers.Dense(512))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Dense(256))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Dense(1, activation='sigmoid'))
return model
discriminator = build_discriminator()
Step 5: Define Loss Functions and Optimizers
在此步骤中,我们将对生成器和判别器都使用二元交叉熵损失。生成器的目的是最大化判别器犯错的概率,而判别器的目的是最小化其分类错误。
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)
def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
Step 6: Define the Training Loop
GAN 的训练过程涉及迭代训练生成器和判别器。在这里,我们将定义一个训练步骤,包括生成伪造图像、计算损失以及使用反向传播更新模型权重。
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, 100])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
for epoch in range(epochs):
for image_batch in dataset:
train_step(image_batch)
print(f'Epoch {epoch+1} completed')
Step 7: Prepare the Dataset and Train the GAN
接下来,我们将通过对 MNIST 图像进行洗牌和批处理来准备数据集,然后开始训练过程。
# Prepare the dataset for training
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# Train the GAN
EPOCHS = 50
train(train_dataset, EPOCHS)
Step 8: Generate and Display Images
现在,在训练 GAN 之后,我们可以生成并显示生成器创建的新图像。它包括创建随机噪声、将它输入生成器,并显示生成图像。
def generate_and_save_images(model, epoch, test_input):
predictions = model(test_input, training=False)
fig = plt.figure(figsize=(7.50, 3.50))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i + 1)
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
plt.show()
seed = tf.random.normal([16, 100])
generate_and_save_images(generator, EPOCHS, seed)
实现后,运行此代码时,您将获得以下输出 -