变分自编码器(Variational Autoencoder,VAE)是深度学习中经典的生成模型之一,它结合了自编码器的结构和变分推断的思想,既能完成数据压缩,又能实现数据生成。本文将从原理到代码,一步步拆解VAE的核心逻辑。
一、VAE的核心思想
VAE的本质是通过学习数据的潜在分布,实现从低维隐空间到高维数据空间的映射。和传统自编码器不同,VAE不是直接学习输入到隐向量的确定性映射,而是学习隐向量的概率分布,这也是它能生成新数据的关键。
1. 传统自编码器的局限
传统自编码器由编码器和解码器组成:编码器将输入数据压缩成固定维度的隐向量,解码器再将隐向量还原为输入数据。但这种结构的隐空间是离散且无规律的,无法通过采样隐向量生成新数据——比如在两个隐向量之间插值,可能得到无意义的结果。
2. VAE的改进:引入概率分布
VAE对编码器做了修改:不再输出固定的隐向量,而是输出隐向量的均值μ和方差σ²(为了计算方便,通常输出logσ²,避免方差为负)。然后从这个正态分布N(μ, σ²)中采样得到隐向量z,再输入解码器还原数据。
这个过程可以用两个核心步骤概括:
- 编码过程:输入x → 编码器输出μ和logσ² → 采样得到z ~ N(μ, σ²)
- 解码过程:z → 解码器输出重构数据x̂
3. VAE的损失函数
VAE的损失由两部分组成:重构损失和KL散度损失。
(1)重构损失
衡量解码器输出的重构数据x̂和原始输入x的差异,通常用交叉熵损失(针对图像等离散数据)或均方误差(针对连续数据):
Lrecon=−Ez∼q(z∣x)[logp(x∣z)]L_{recon} = -\mathbb{E}_{z \sim q(z|x)}[\log p(x|z)]Lrecon=−Ez∼q(z∣x)[logp(x∣z)]
简单来说,就是让重构数据尽可能接近原始数据。
(2)KL散度损失
KL散度用于衡量编码器输出的分布q(z|x)和预设的先验分布p(z)(通常设为标准正态分布N(0,1))之间的差异:
LKL=DKL(q(z∣x)∣∣p(z))=12∑i=1d(μi2+σi2−logσi2−1)L_{KL} = D_{KL}(q(z|x) || p(z)) = \frac{1}{2}\sum_{i=1}^d (\mu_i^2 + \sigma_i^2 - \log\sigma_i^2 - 1)LKL=DKL(q(z∣x)∣∣p(z))=21i=1∑d(μi2+σi2−logσi2−1)
这部分损失的作用是约束隐空间的分布尽可能接近标准正态分布,保证隐空间的连续性和规律性,这样在隐空间中采样就能生成有意义的数据。
最终VAE的总损失为:
L=Lrecon+LKLL = L_{recon} + L_{KL}L=Lrecon+LKL
二、重参数化技巧
这里有个关键问题:如果直接从N(μ, σ²)中采样z,反向传播时梯度无法通过采样操作传递(因为采样是随机过程,不可导)。为了解决这个问题,VAE引入了重参数化技巧:
将采样过程改写为:
z=μ+σ⊙ϵ,ϵ∼N(0,1)z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim N(0,1)z=μ+σ⊙ϵ,ϵ∼N(0,1)
其中⊙表示元素-wise乘法。这样一来,采样的随机性转移到了ε上,而μ和σ是编码器的输出,梯度可以通过μ和σ反向传播,解决了不可导的问题。
三、PyTorch代码实现
1. 定义VAE模型
classVAE(nn.Module):def__init__(self,input_dim=784,hidden_dim=256,latent_dim=20):super().__init__()self.encoder=nn.Sequential(nn.Linear(input_dim,hidden_dim),nn.ReLU(),nn.Linear(hidden_dim,hidden_dim),nn.ReLU())self.fc_mu=nn.Linear(hidden_dim,latent_dim)self.fc_logvar=nn.Linear(hidden_dim,latent_dim)self.decoder=nn.Sequential(nn.Linear(latent_dim,hidden_dim),nn.ReLU(),nn.Linear(hidden_dim,hidden_dim),nn.ReLU(),nn.Linear(hidden_dim,input_dim),nn.Sigmoid())defencode(self,x):h=self.encoder(x)returnself.fc_mu(h),self.fc_logvar(h)defreparameterize(self,mu,log_var):std=torch.exp(0.5*log_var)returnmu+torch.randn_like(std)*stddefforward(self,x):mu,log_var=self.encode(x)z=self.reparameterize(mu,log_var)returnself.decode(z),mu,log_var2. 损失函数
bce_loss=nn.BCELoss(reduction='sum')defloss_function(x_recon,x,mu,log_var):recon_loss=bce_loss(x_recon,x)kl_loss=-0.5*torch.sum(1+log_var-mu.pow(2)-log_var.exp())returnrecon_loss+kl_loss3. 训练
transform=transforms.Compose([transforms.ToTensor()])train_dataset=datasets.MNIST(root='./data',train=True,download=True,transform=transform)train_loader=DataLoader(train_dataset,batch_size=128,shuffle=True)device=torch.device('cuda'iftorch.cuda.is_available()else'cpu')model=VAE().to(device)optimizer=optim.Adam(model.parameters(),lr=1e-3)forepochinrange(50):total_loss=0fordata,_intrain_loader:data=data.view(-1,784).to(device)optimizer.zero_grad()x_recon,mu,log_var=model(data)loss=loss_function(x_recon,data,mu,log_var)loss.backward()total_loss+=loss.item()optimizer.step()print(f'Epoch{epoch+1}, Avg Loss:{total_loss/len(train_loader.dataset):.4f}')4. 生成新数据
model.eval()withtorch.no_grad():z=torch.randn(25,20).to(device)generated_imgs=model.decode(z).cpu().numpy()⚠️注意:本文仅为学习和理解算法进行 demo 代码实现,线上和生产环境不建议使用。