VAE:原理+代码全解析
2026/6/10 6:53:59 网站建设 项目流程

变分自编码器(Variational Autoencoder,VAE)是深度学习中经典的生成模型之一,它结合了自编码器的结构和变分推断的思想,既能完成数据压缩,又能实现数据生成。本文将从原理到代码,一步步拆解VAE的核心逻辑。

一、VAE的核心思想

VAE的本质是通过学习数据的潜在分布,实现从低维隐空间到高维数据空间的映射。和传统自编码器不同,VAE不是直接学习输入到隐向量的确定性映射,而是学习隐向量的概率分布,这也是它能生成新数据的关键。

1. 传统自编码器的局限

传统自编码器由编码器和解码器组成:编码器将输入数据压缩成固定维度的隐向量,解码器再将隐向量还原为输入数据。但这种结构的隐空间是离散且无规律的,无法通过采样隐向量生成新数据——比如在两个隐向量之间插值,可能得到无意义的结果。

2. VAE的改进:引入概率分布

VAE对编码器做了修改:不再输出固定的隐向量,而是输出隐向量的均值μ和方差σ²(为了计算方便,通常输出logσ²,避免方差为负)。然后从这个正态分布N(μ, σ²)中采样得到隐向量z,再输入解码器还原数据。

这个过程可以用两个核心步骤概括:

  • 编码过程:输入x → 编码器输出μ和logσ² → 采样得到z ~ N(μ, σ²)
  • 解码过程:z → 解码器输出重构数据x̂

3. VAE的损失函数

VAE的损失由两部分组成:重构损失和KL散度损失。

(1)重构损失

衡量解码器输出的重构数据x̂和原始输入x的差异,通常用交叉熵损失(针对图像等离散数据)或均方误差(针对连续数据):
Lrecon=−Ez∼q(z∣x)[log⁡p(x∣z)]L_{recon} = -\mathbb{E}_{z \sim q(z|x)}[\log p(x|z)]Lrecon=Ezq(zx)[logp(xz)]
简单来说,就是让重构数据尽可能接近原始数据。

(2)KL散度损失

KL散度用于衡量编码器输出的分布q(z|x)和预设的先验分布p(z)(通常设为标准正态分布N(0,1))之间的差异:
LKL=DKL(q(z∣x)∣∣p(z))=12∑i=1d(μi2+σi2−log⁡σi2−1)L_{KL} = D_{KL}(q(z|x) || p(z)) = \frac{1}{2}\sum_{i=1}^d (\mu_i^2 + \sigma_i^2 - \log\sigma_i^2 - 1)LKL=DKL(q(zx)∣∣p(z))=21i=1d(μi2+σi2logσi21)
这部分损失的作用是约束隐空间的分布尽可能接近标准正态分布,保证隐空间的连续性和规律性,这样在隐空间中采样就能生成有意义的数据。

最终VAE的总损失为:
L=Lrecon+LKLL = L_{recon} + L_{KL}L=Lrecon+LKL

二、重参数化技巧

这里有个关键问题:如果直接从N(μ, σ²)中采样z,反向传播时梯度无法通过采样操作传递(因为采样是随机过程,不可导)。为了解决这个问题,VAE引入了重参数化技巧

将采样过程改写为:
z=μ+σ⊙ϵ,ϵ∼N(0,1)z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim N(0,1)z=μ+σϵ,ϵN(0,1)
其中⊙表示元素-wise乘法。这样一来,采样的随机性转移到了ε上,而μ和σ是编码器的输出,梯度可以通过μ和σ反向传播,解决了不可导的问题。

三、PyTorch代码实现

1. 定义VAE模型

classVAE(nn.Module):def__init__(self,input_dim=784,hidden_dim=256,latent_dim=20):super().__init__()self.encoder=nn.Sequential(nn.Linear(input_dim,hidden_dim),nn.ReLU(),nn.Linear(hidden_dim,hidden_dim),nn.ReLU())self.fc_mu=nn.Linear(hidden_dim,latent_dim)self.fc_logvar=nn.Linear(hidden_dim,latent_dim)self.decoder=nn.Sequential(nn.Linear(latent_dim,hidden_dim),nn.ReLU(),nn.Linear(hidden_dim,hidden_dim),nn.ReLU(),nn.Linear(hidden_dim,input_dim),nn.Sigmoid())defencode(self,x):h=self.encoder(x)returnself.fc_mu(h),self.fc_logvar(h)defreparameterize(self,mu,log_var):std=torch.exp(0.5*log_var)returnmu+torch.randn_like(std)*stddefforward(self,x):mu,log_var=self.encode(x)z=self.reparameterize(mu,log_var)returnself.decode(z),mu,log_var

2. 损失函数

bce_loss=nn.BCELoss(reduction='sum')defloss_function(x_recon,x,mu,log_var):recon_loss=bce_loss(x_recon,x)kl_loss=-0.5*torch.sum(1+log_var-mu.pow(2)-log_var.exp())returnrecon_loss+kl_loss

3. 训练

transform=transforms.Compose([transforms.ToTensor()])train_dataset=datasets.MNIST(root='./data',train=True,download=True,transform=transform)train_loader=DataLoader(train_dataset,batch_size=128,shuffle=True)device=torch.device('cuda'iftorch.cuda.is_available()else'cpu')model=VAE().to(device)optimizer=optim.Adam(model.parameters(),lr=1e-3)forepochinrange(50):total_loss=0fordata,_intrain_loader:data=data.view(-1,784).to(device)optimizer.zero_grad()x_recon,mu,log_var=model(data)loss=loss_function(x_recon,data,mu,log_var)loss.backward()total_loss+=loss.item()optimizer.step()print(f'Epoch{epoch+1}, Avg Loss:{total_loss/len(train_loader.dataset):.4f}')

4. 生成新数据

model.eval()withtorch.no_grad():z=torch.randn(25,20).to(device)generated_imgs=model.decode(z).cpu().numpy()

⚠️注意:本文仅为学习和理解算法进行 demo 代码实现,线上和生产环境不建议使用。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询