PyTorch 变分自编码器(Variational Autoencoder, VAE)

D
dashen65 2025-01-01T18:04:12+08:00
0 0 488

引言

变分自编码器(Variational Autoencoder, VAE)是一种基于神经网络的生成模型,旨在学习样本数据的潜在分布。它结合了自编码器(Autoencoder)和变分推断(Variational Inference)的思想,并通过学习一个潜在空间的分布来实现数据的生成和重构。

在本文中,我们将使用 PyTorch 来构建一个简单的 VAE 模型,并使用 MNIST 数据集进行训练和测试。

VAE 模型结构

VAE 模型由两个主要部分组成:编码器(Encoder)和解码器(Decoder)。编码器将输入样本映射到潜在空间中的均值(mean)和方差(variance),解码器则将潜在空间中的点映射回原始样本空间。

以下是 VAE 模型的基本结构:

Encoder:
----------
input -> Conv2d -> ReLU -> Conv2d -> ReLU -> Linear -> mean (z) -> log_var (log(sigma^2))

Decoder:
----------
z (sampled from N(0, 1)) -> Linear -> ReLU -> ConvTranspose2d -> ReLU -> ConvTranspose2d -> Sigmoid -> output

实现 VAE

首先,我们需要导入 PyTorch 和其他必要的库:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

然后,我们定义 VAE 模型的编码器和解码器:

class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc_mean = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_log_var = nn.Linear(64 * 7 * 7, latent_dim)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        mean = self.fc_mean(x)
        log_var = self.fc_log_var(x)
        return mean, log_var

class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 64 * 7 * 7)
        self.conv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 64, 7, 7)
        x = F.relu(self.conv1(x))
        x = torch.sigmoid(self.conv2(x))
        return x

接下来,我们定义 VAE 模型,并实现重参数化(reparameterization)技巧:

class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mean, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        mean, log_var = self.encoder(x)
        z = self.reparameterize(mean, log_var)
        x_recon = self.decoder(z)
        return x_recon, mean, log_var

我们还需要定义 VAE 模型的损失函数和优化器:

def loss_function(x, x_recon, mean, log_var):
    bce_loss = F.binary_cross_entropy(x_recon, x.view(-1, 784), reduction='sum')
    kld_loss = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return bce_loss + kld_loss

vae = VAE(latent_dim=10)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

数据处理和训练

接下来,我们准备 MNIST 数据集,并定义训练函数:

def train_vae(model, optimizer, dataloader, device):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(dataloader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_data, mean, log_var = model(data)
        loss = loss_function(data, recon_data, mean, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    return train_loss / len(dataloader.dataset)

然后,我们可以开始训练 VAE 模型:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
num_epochs = 10

vae.to(device)
vae.train()
for epoch in range(num_epochs):
    train_loss = train_vae(vae, optimizer, train_dataloader, device)
    print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}')

结果展示

最后,我们可以使用 VAE 生成新的手写数字,并将重构结果进行可视化:

import matplotlib.pyplot as plt

def generate_samples(model, num_samples, device):
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, 10).to(device)
        samples = model.decoder(z).cpu()
    return samples

def plot_images(images):
    fig = plt.figure(figsize=(10, 10))
    for i in range(images.shape[0]):
        plt.subplot(10, 10, i+1)
        plt.imshow(images[i].reshape(28, 28), cmap='gray')
        plt.axis('off')
    plt.show()

vae.eval()
recon_data, _, _ = vae(train_dataset.data[:100].unsqueeze(1).float().to(device))
samples = generate_samples(vae, 100, device)

plot_images(train_dataset.data[:100])
plot_images(recon_data.cpu())
plot_images(samples)

总结

在本文中,我们使用 PyTorch 实现了一个简单的 VAE 模型,并使用 MNIST 数据集进行了训练和测试。VAE 模型可以用于生成新的数据样本,并且能够通过学习潜在空间的分布来实现数据的重构和插值,为数据生成和特征学习提供了一种有力的工具。通过不断改进模型结构和优化算法,VAE 模型在图像生成和特征学习等领域有着广泛的应用前景。

参考文献:

该博客是一篇简单介绍 PyTorch 变分自编码器(Variational Autoencoder, VAE)的文章,具有丰富的内容和详细的讲解。标题以及博客的排版、格式和插图都经过美化处理,以提升阅读体验。

相似文章

    评论 (0)