vae代码阅读
在阅读flux代码(pipeline_flux.Fluxpipeline类中的__call__方法最后几句)对diffusion采样得到的latent转换回vae输入的这段代码时,看到标蓝这段代码很疑惑,因此打算复习一下vae,研究一下为什么要乘以缩放系数以及进行偏置。
查阅网上博客后,有人说是因为pixel space变成latent space之后的值都特别大,因此需要一个缩放因子来让范围变小,同时进行偏置使得范围合理。 这是我认为比较正确的理解。
官方文档的解释:https://huggingface.co/docs/diffusers/main/en/api/models/autoencoderkl
1 |
|
原理解析
两个主要部分组成:
- 编码器(Encoder):将输入数据 $ x $ 映射到潜在空间 $ \mathbf{z} $。
- 解码器(Decoder):将潜在表示 $ \mathbf{z} $ 重构回原始数据空间 $ \mathbf{x}’ $。
VAE将数据生成过程建模为一个概率过程:
-
潜在变量的先验分布:假设潜在变量 $ \mathbf{z} $ 服从某个先验分布,通常选择标准正态分布:
-
生成模型:给定潜在变量 $ \mathbf{z} $,生成数据 $ \mathbf{x} $ 的条件分布:
其中,$ \mu(\mathbf{z}) $ 和 $ \sigma(\mathbf{z}) $ 由解码器网络参数化。
直接计算后验分布 $ p(\mathbf{z}|\mathbf{x}) $ 通常非常困难,因此VAE使用变分推断,通过引入一个可参数化的近似分布 $ q(\mathbf{z}|\mathbf{x}) $ 来逼近真实的后验分布。
为了训练模型,VAE最大化证据下界(Evidence Lower Bound, ELBO):
其中:
-
重构项:
- 衡量模型重构数据的能力。
-
正则化项:
- 衡量近似后验分布与先验分布之间的差异,确保潜在空间的连续性和规则性。
VAE的损失函数由两个部分组成:
-
重构损失(Reconstruction Loss):
通常使用均方误差(MSE)或交叉熵作为具体形式。
-
KL散度损失(KL Divergence Loss):
对于高斯分布,可以计算解析解:
最终的VAE损失函数为:
其中,$ \beta $ 是一个超参数,用于控制重构损失和KL散度损失之间的权衡。
训练阶段
-
编码:
- 输入数据 $ \mathbf{x} $ 通过编码器网络,输出潜在变量的参数 $ \mu(\mathbf{x}) $ 和 $ \log \sigma^2(\mathbf{x}) $。
-
重参数化技巧(Reparameterization Trick):
- 为了实现反向传播,使用重参数化技巧从 $ q(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\mu(\mathbf{x}), \sigma^2(\mathbf{x})\mathbf{I}) $ 中采样:
- 为了实现反向传播,使用重参数化技巧从 $ q(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\mu(\mathbf{x}), \sigma^2(\mathbf{x})\mathbf{I}) $ 中采样:
-
解码:
- 潜在变量 $ \mathbf{z} $ 通过解码器网络生成重构数据 $ \mathbf{x}’ $。
-
损失计算与优化:
- 计算重构损失和KL散度损失,优化整个网络以最小化总损失。
Diffusers 代码解析
flux代码中使用的vae是属于diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL类,因此来阅读对应的代码,具体的模型代码就不管了,主要关注这个使用的流程。
encode代码:调用_encode(x)得到z的均值和方差预测h = [mean, logvar],然后传入DiagonalGaussianDistribution类准备变量采样。
1 |
|
_encode函数:接收image作为输入,然后输出刚才的[mean, logvar]。
1 |
|
DiagonalGaussianDistribution类代码: 传入预测的[mean, logvar],可以调用sample采样得到重参数化的z。kl和nll应该是训练用的loss。
1 |
|
decode:用上面采样得到的z重建image x。
1 |
|
_decode:用采样得到的z重建image x的具体代码。
1 |
|
DecoderOutput:就是一个存放输出的sample的数据类。
1 |
|
1. 数据类(@dataclass
)的优势
1.1 提高代码的可读性和可维护性
结构化的数据表示:
-
使用
@dataclass
定义的类明确地展示了数据的结构和组成部分。每个字段都有明确的名称和类型,这使得代码更加自解释,易于理解。1
2
3
4@dataclass
class DecoderOutput:
sample: torch.Tensor
commit_loss: Optional[torch.FloatTensor] = None相比之下,使用元组或字典时,字段的意义可能不那么直观:
1
2
3
4
5# 使用元组
return sample, commit_loss
# 使用字典
return {"sample": sample, "commit_loss": commit_loss}
1.2 类型检查和静态分析支持
类型提示:
-
数据类允许为每个字段指定类型,这有助于静态类型检查工具(如 MyPy)在编译时捕捉类型错误,提升代码的可靠性。
1
2
3
4
5
6
7
8from typing import Optional
import torch
from dataclasses import dataclass
@dataclass
class DecoderOutput:
sample: torch.Tensor
commit_loss: Optional[torch.FloatTensor] = None使用元组或字典时,类型信息不够明确,可能导致类型错误更难以检测。
1.3 自动生成的特殊方法
自动生成方法:
-
@dataclass
自动为类生成常用的特殊方法,如__init__
,__repr__
,__eq__
等。这减少了样板代码,提高了开发效率。1
2
3
4@dataclass
class DecoderOutput:
sample: torch.Tensor
commit_loss: Optional[torch.FloatTensor] = None上述代码自动生成了以下方法:
1
2
3
4
5
6
7
8
9
10
11def __init__(self, sample: torch.Tensor, commit_loss: Optional[torch.FloatTensor] = None):
self.sample = sample
self.commit_loss = commit_loss
def __repr__(self):
return f"DecoderOutput(sample={self.sample}, commit_loss={self.commit_loss})"
def __eq__(self, other):
if not isinstance(other, DecoderOutput):
return False
return self.sample == other.sample and self.commit_loss == other.commit_loss使用元组或字典,需要手动实现这些方法(如果需要),增加了代码复杂性。
1.4 更好的文档和代码提示
文档字符串和 IDE 支持:
-
数据类可以包含文档字符串(docstrings),为类和其字段提供详细说明。这对开发者来说非常有帮助,尤其是在使用 IDE 时,能够显示详细的提示信息。
1
2
3
4
5
6
7
8
9
10
11@dataclass
class DecoderOutput:
"""
Output of decoding method.
Args:
sample (torch.Tensor): The decoded output sample from the last layer of the model.
commit_loss (Optional[torch.FloatTensor]): Additional loss for committing to the latent space.
"""
sample: torch.Tensor
commit_loss: Optional[torch.FloatTensor] = None使用元组或字典时,缺乏这种内置的文档支持,开发者需要依赖外部文档或代码注释。
1.5 错误减少和一致性提升
避免字段顺序错误:
-
使用数据类时,字段是通过名称访问的,减少了由于字段顺序错误导致的 bug。相比之下,使用元组时,必须按正确的顺序访问元素,容易出错。
1
2
3output = DecoderOutput(sample=decoded_sample, commit_loss=loss)
print(output.sample)
print(output.commit_loss)使用元组:
1
2
3output = (decoded_sample, loss)
print(output[0]) # sample
print(output[1]) # commit_loss如果不小心交换了顺序,可能会导致难以发现的错误。
1.6 易于扩展和修改
方便的字段添加和修改:
-
当需要向输出中添加新字段时,数据类只需在类定义中添加新字段,其他部分代码无需大幅修改。使用元组或字典时,可能需要修改多个代码位置来适应新字段。
1
2
3
4
5@dataclass
class DecoderOutput:
sample: torch.Tensor
commit_loss: Optional[torch.FloatTensor] = None
additional_info: Optional[dict] = None使用元组时:
1
2
3
4
5# 原先返回
return sample, commit_loss
# 修改后需要返回更多元素
return sample, commit_loss, additional_info这会影响所有调用该函数的地方,增加了维护成本。
1.7 更好的错误检测
属性访问:
-
使用数据类时,通过属性访问字段,可以在编译时或静态分析时捕捉到不存在的属性访问错误。而使用元组或字典时,这类错误可能只在运行时才会被发现。
1
2
3
4
5
6
7
8
9
10
11
12
13
14# 数据类
output = DecoderOutput(sample, commit_loss)
print(output.sample) # 正确
print(output.non_existent) # AttributeError
# 元组
output = (sample, commit_loss)
print(output[0]) # 正确
print(output[2]) # IndexError
# 字典
output = {"sample": sample, "commit_loss": commit_loss}
print(output["sample"]) # 正确
print(output["non_existent"]) # KeyError
1.8 简化调试和日志记录
清晰的输出:
-
数据类提供了清晰的
__repr__
方法,使得在调试和日志记录时,输出信息更具可读性和可解释性。1
2
3
4
5
6
7
8@dataclass
class DecoderOutput:
sample: torch.Tensor
commit_loss: Optional[torch.FloatTensor] = None
output = DecoderOutput(sample, commit_loss)
print(output)
# 输出: DecoderOutput(sample=tensor([...]), commit_loss=tensor([...]))使用元组或字典时,输出可能不够直观:
1
2
3
4
5
6
7output = (sample, commit_loss)
print(output)
# 输出: (tensor([...]), tensor([...]))
output = {"sample": sample, "commit_loss": commit_loss}
print(output)
# 输出: {'sample': tensor([...]), 'commit_loss': tensor([...])}
1.9 更好的集成和扩展性
与其他类和方法的集成:
-
数据类可以作为更复杂数据结构的一部分,便于与其他类和方法集成。例如,可以在数据类中嵌套其他数据类,形成更复杂的层次结构。
1
2
3
4@dataclass
class ModelOutput:
decoder_output: DecoderOutput
additional_info: Optional[dict] = None
1.10 支持不可变性(可选)
不可变数据类:
-
通过设置
frozen=True
,可以创建不可变的数据类,这在某些情况下有助于防止数据被意外修改,提升代码的安全性和稳定性。1
2
3
4@dataclass(frozen=True)
class DecoderOutput:
sample: torch.Tensor
commit_loss: Optional[torch.FloatTensor] = None这样,一旦实例化,就无法修改其字段:
1
2output = DecoderOutput(sample, commit_loss)
output.sample = torch.randn(64, 784) # Raises FrozenInstanceError