Gen-AI Concise Tutorial
Training a Generative Adversarial Network (GAN)
We have explored the architecture of Generative Adversarial Networks and how they work. In this chapter, we take a practical example to demonstrate how to implement and train a GAN that generates handwritten digits like those in the MNIST dataset. We’ll use Python along with TensorFlow and Keras for this example.
Process of Training a Generative Adversarial Network
The training of GANs involves optimizing both the generator model and the discriminator model iteratively. Let’s understand the training process of a Generative Adversarial Network (GAN) using the following steps:
Step 1: Initialization
The process starts with two neural networks: the Generator Network (G) and the Discriminator Network (D).
The Generator takes a random seed or noise vector as input and produces generated samples.
The Discriminator takes either real data samples or generated samples as input and classifies them as real or fake.
Step 2: Generating Fake Data
A random noise vector is fed into the Generator Network.
The Generator processes this noise and outputs generated data samples that are intended to resemble real data.
Step 3: Generator Training
First, the generator produces fake data from random input noise.
Then the generator’s loss is calculated from the discriminator’s output on that fake data.
Finally, the generator’s weights are updated to minimize this loss.
Step 4: Discriminator Training
First, the discriminator takes a batch of real data and a batch of fake data.
Then its loss is calculated on both the real and the fake batch.
Finally, the discriminator’s weights are updated to minimize this loss.
Step 5: Iterative Training
Repeat steps 2 to 4. In each iteration, the Generator and Discriminator are trained alternately, and each improves in response to the other.
This alternating optimization continues until the generated data closely resembles the real data and the discriminator can no longer reliably distinguish between real and fake data.
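The alternating procedure above is commonly summarized as a two-player minimax game. The text does not state the objective explicitly, so the notation here is an assumption introduced for illustration: $p_{\text{data}}$ is the real data distribution, $p_z$ is the noise distribution, and $V$ is the value function.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator step raises $V$ by pushing $D(x)$ toward 1 and $D(G(z))$ toward 0, while the generator step pushes $D(G(z))$ toward 1. In practice the generator is often trained to minimize $-\log D(G(z))$ instead, which is what a binary cross-entropy loss against a target of 1 computes.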
Training and Building a GAN
Here, we will show the step-by-step procedure of training and building a GAN using Python and the MNIST dataset −
Step 1: Setting Up the Environment
Before we start, we need to set up our Python environment with the necessary libraries. Ensure you have TensorFlow installed on your computer; Keras ships as part of TensorFlow 2, so a single pip command is enough −
pip install tensorflow
Step 2: Import Necessary Libraries
We need to import the essential libraries −
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
Step 3: Load and Preprocess the MNIST Dataset
The MNIST dataset consists of 60,000 training images and 10,000 testing images of handwritten digits, each of size 28x28 pixels. We will normalize the pixel values to the range [-1, 1], which matches the tanh output of the generator defined in Step 4 and helps training −
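# Load the dataset (labels and the test split are not needed)
(x_train, _), (_, _) = mnist.load_data()
# Normalize the images to [-1, 1] and cast to float32
x_train = (x_train.astype('float32') - 127.5) / 127.5
# Add a channel dimension: (60000, 28, 28) -> (60000, 28, 28, 1)
x_train = np.expand_dims(x_train, axis=-1)
# Set batch size and buffer size
# (the values below are assumed; the original listing omits them)
BUFFER_SIZE = 60000
BATCH_SIZE = 256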
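The normalization step can be checked in isolation. The following is a minimal sketch, assuming NumPy only; the helper names `normalize` and `denormalize` are hypothetical and not part of the tutorial’s code. It shows that 0 maps to -1, 255 maps to 1, and that the mapping can be inverted when images are displayed later:

```python
import numpy as np

def normalize(pixels):
    """Map uint8 pixel values in [0, 255] to floats in [-1, 1]."""
    return (pixels.astype("float32") - 127.5) / 127.5

def denormalize(values):
    """Map floats in [-1, 1] back to the [0, 255] pixel range."""
    return values * 127.5 + 127.5

# A few representative pixel values, including both extremes
sample = np.array([0, 127, 128, 255], dtype=np.uint8)
scaled = normalize(sample)
restored = denormalize(scaled)
print(scaled)
print(restored)
```

The inverse mapping is the same expression that the display code in Step 8 applies to generated images.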
Step 4: Create the Generator and Discriminator Models
The generator creates fake images from random noise, and the discriminator attempts to distinguish between real and fake images.
The generator model takes a random noise vector as input and transforms it through a series of layers to produce a fake image −
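def build_generator():
    model = models.Sequential()
    model.add(layers.Dense(256, use_bias=False, input_shape=(100,)))
    # Non-linearity between the dense layers (added; without it the stack is purely linear)
    model.add(layers.LeakyReLU())
    model.add(layers.Dense(512, use_bias=False))
    model.add(layers.LeakyReLU())
    # tanh keeps pixel values in [-1, 1], matching the normalized training data
    model.add(layers.Dense(28 * 28 * 1, use_bias=False, activation='tanh'))
    model.add(layers.Reshape((28, 28, 1)))
    return model

generator = build_generator()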
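To make the shapes concrete, here is a rough NumPy sketch of the same dense stack. It is not the Keras model: the weights are random stand-ins, the function name `generator_forward` is hypothetical, and any hidden-layer activations are left out because they do not change shapes. It only traces how a batch of 100-dimensional noise vectors becomes a batch of 28x28x1 images:

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Layer widths taken from build_generator: 100 -> 256 -> 512 -> 784
sizes = [100, 256, 512, 28 * 28 * 1]
# Random stand-in weights (no biases, mirroring use_bias=False)
weights = [rng.normal(scale=0.05, size=(n_in, n_out))
           for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def generator_forward(noise):
    """Trace a batch of noise vectors through the dense stack."""
    h = noise
    for w in weights[:-1]:
        h = h @ w                     # hidden dense layers (shape-only trace)
    h = np.tanh(h @ weights[-1])      # output layer squashes values into [-1, 1]
    return h.reshape(-1, 28, 28, 1)   # same effect as layers.Reshape((28, 28, 1))

noise = rng.normal(size=(16, 100))
images = generator_forward(noise)
n_params = sum(w.size for w in weights)
print(images.shape, n_params)
```

With these layer widths the stack holds 100·256 + 256·512 + 512·784 = 558,080 weights, and the output has shape (16, 28, 28, 1).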
The discriminator model takes an image as input (either real or generated) and outputs a probability value indicating whether the image is real or fake −
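def build_discriminator():
    model = models.Sequential()
    # Flatten the 28x28x1 image into a 784-element vector
    model.add(layers.Flatten(input_shape=(28, 28, 1)))
    # A single sigmoid unit outputs the probability that the image is real
    model.add(layers.Dense(1, activation='sigmoid'))
    return model

discriminator = build_discriminator()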
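What this discriminator computes can be written out by hand. The sketch below uses NumPy with stand-in weights rather than the Keras model; `sigmoid` and `discriminator_forward` are hypothetical helper names. It flattens each image and applies one sigmoid unit, so every image yields a single value strictly between 0 and 1:

```python
import numpy as np

def sigmoid(x):
    """Logistic function, mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

def discriminator_forward(images, w, b):
    """Flatten each image and apply one sigmoid unit."""
    flat = images.reshape(images.shape[0], -1)   # (batch, 784)
    return sigmoid(flat @ w + b)                 # (batch, 1)

rng = np.random.default_rng(seed=0)
w = rng.normal(scale=0.01, size=(28 * 28, 1))    # stand-in weights
b = np.zeros(1)                                  # stand-in bias
batch = rng.uniform(-1.0, 1.0, size=(4, 28, 28, 1))
probs = discriminator_forward(batch, w, b)
print(probs.shape)
```

With all-zero weights and bias the output is exactly 0.5 for every image, which corresponds to a discriminator with no preference between real and fake.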
Step 5: Define Loss Functions and Optimizers
In this step, we will use binary cross-entropy loss for both the generator and discriminator. The generator aims to maximize the probability of the discriminator making a mistake, while the discriminator aims to minimize its classification error.
# The discriminator ends in a sigmoid, so its outputs are probabilities, not logits
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

def generator_loss(fake_output):
    # The generator wants its fake images to be classified as real (label 1)
    return cross_entropy(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    # Real images should be classified as 1, fake images as 0
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
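The two losses are easier to interpret with a few numbers. The following is a small sketch in plain Python, assuming the discriminator output is a probability; the function names `bce`, `generator_loss_scalar`, and `discriminator_loss_scalar` are hypothetical and work on single values rather than tensors:

```python
import math

def bce(target, prob, eps=1e-7):
    """Binary cross-entropy for one probability and a 0/1 target."""
    prob = min(max(prob, eps), 1.0 - eps)   # clip to avoid log(0)
    return -(target * math.log(prob) + (1 - target) * math.log(1.0 - prob))

def generator_loss_scalar(d_fake):
    # The generator is scored against the target "real" (1)
    return bce(1, d_fake)

def discriminator_loss_scalar(d_real, d_fake):
    # Real samples are scored against 1, fake samples against 0
    return bce(1, d_real) + bce(0, d_fake)

# A discriminator that cannot tell real from fake outputs 0.5 for everything
undecided_g = generator_loss_scalar(0.5)
undecided_d = discriminator_loss_scalar(0.5, 0.5)
# A confident, correct discriminator: the generator loss becomes large
confident_g = generator_loss_scalar(0.01)
confident_d = discriminator_loss_scalar(0.99, 0.01)
print(undecided_g, undecided_d, confident_g, confident_d)
```

When the discriminator outputs 0.5 everywhere, the generator loss is ln 2 (about 0.693) and the discriminator loss is 2 ln 2 (about 1.386). Losses hovering near these values during training suggest that neither network is dominating the other.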
Step 6: Define the Training Loop
The training process for a GAN involves training the generator and discriminator iteratively. Here, we will define a training step that includes generating fake images, calculating losses, and updating the model weights using backpropagation.
def train_step(images):
    # Draw one noise vector per real image in the batch
    noise = tf.random.normal([tf.shape(images)[0], 100])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    # Compute the gradient of each loss with respect to its own network's weights
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    # Apply the updates
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
    for epoch in range(epochs):
        # Run one training step per batch
        for image_batch in dataset:
            train_step(image_batch)
        print(f'Epoch {epoch+1} completed')
Step 7: Prepare the Dataset and Train the GAN
Next, we will prepare the dataset by shuffling and batching the MNIST images and then we will start the training process.
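# Number of passes over the training set (assumed value; the original listing omits it)
EPOCHS = 50
# Prepare the dataset for training: shuffle, then split into batches
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# Train the GAN
train(train_dataset, EPOCHS)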
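It can help to know how many training steps one epoch involves. The arithmetic below is a sketch only: the batch size of 256 is an assumed value chosen for illustration, and it assumes batching is done with `tf.data`’s `Dataset.batch`, whose `drop_remainder` argument defaults to `False` so the final, smaller batch is kept:

```python
import math

NUM_IMAGES = 60000   # MNIST training set size, as stated in Step 3
BATCH_SIZE = 256     # assumed value for illustration

# Number of batches (training steps) in one pass over the data
batches_per_epoch = math.ceil(NUM_IMAGES / BATCH_SIZE)
# Size of the final batch when the remainder is kept
last_batch_size = NUM_IMAGES - (batches_per_epoch - 1) * BATCH_SIZE
print(batches_per_epoch, last_batch_size)
```

Under these assumptions an epoch consists of 235 steps, and the last batch holds 96 images instead of 256.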
Step 8: Generate and Display Images
Now, after training the GAN, we can generate and display new images created by the generator. It involves creating random noise, feeding it to the generator, and displaying the resulting images.
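def generate_and_save_images(model, epoch, test_input):
    # training=False runs the model in inference mode
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(7.50, 3.50))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        # Map pixel values from [-1, 1] back to [0, 255] for display
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    # Save and show the grid (the file name is an assumption; the original listing omits this step)
    plt.savefig(f'image_at_epoch_{epoch:04d}.png')
    plt.show()

seed = tf.random.normal([16, 100])
generate_and_save_images(generator, EPOCHS, seed)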
When you run this code, it displays a 4x4 grid of 16 digit images produced by the generator from the random seed vectors.
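As an alternative to drawing 16 separate subplots, the generated batch can be tiled into a single image array. This is a sketch under stated assumptions: `tile_images` is a hypothetical helper, and a random array stands in for the generator output so the example runs without a trained model:

```python
import numpy as np

def tile_images(batch, rows, cols):
    """Arrange (rows*cols, H, W, 1) images into one (rows*H, cols*W) array."""
    n, h, w, _ = batch.shape
    assert n == rows * cols, "batch size must equal rows * cols"
    grid = batch[..., 0].reshape(rows, cols, h, w)   # split the batch axis into a grid
    grid = grid.transpose(0, 2, 1, 3)                # reorder to (rows, H, cols, W)
    return grid.reshape(rows * h, cols * w)

# Stand-in for generator output: 16 images with values in [-1, 1]
rng = np.random.default_rng(seed=0)
fake_batch = rng.uniform(-1.0, 1.0, size=(16, 28, 28, 1))
grid = tile_images(fake_batch, rows=4, cols=4)
# Convert from [-1, 1] to displayable 8-bit pixel values
pixels = np.clip(grid * 127.5 + 127.5, 0, 255).astype(np.uint8)
print(pixels.shape)
```

The resulting 112x112 array can be passed to a single `plt.imshow(pixels, cmap='gray')` call or written to disk as one image.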

Training a GAN using Python involves several key steps such as setting up the environment, creating the generator and discriminator models, defining loss functions and optimizers, and implementing the training loop. By following these steps, you can train your own GAN and explore the fascinating world of generative adversarial networks.
In this chapter, we provided a detailed guide to building and training a GAN using the Python programming language. We used the TensorFlow and Keras libraries and the MNIST dataset for our example.