vae代码阅读

在阅读flux代码(pipeline_flux.Fluxpipeline类中的__call__方法最后几句)对diffusion采样得到的latent转换回vae输入的这段代码时,看到标蓝这段代码很疑惑,因此打算复习一下vae,研究一下为什么要乘以缩放系数以及进行偏置。

查阅网上博客后,有人说是因为pixel space变成latent space之后的值都特别大,因此需要一个缩放因子来让范围变小,同时进行偏置使得范围合理。 这是我认为比较正确的理解。

官方文档的解释:https://huggingface.co/docs/diffusers/main/en/api/models/autoencoderkl

1
2
3
4
5
6
7
8
if output_type == "latent":
image = latents

else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
'latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor'
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)

原理解析

两个主要部分组成:

  • 编码器(Encoder):将输入数据 $ x $ 映射到潜在空间 $ \mathbf{z} $。
  • 解码器(Decoder):将潜在表示 $ \mathbf{z} $ 重构回原始数据空间 $ \mathbf{x}’ $。

VAE将数据生成过程建模为一个概率过程:

  1. 潜在变量的先验分布:假设潜在变量 $ \mathbf{z} $ 服从某个先验分布,通常选择标准正态分布:

    p(z)=N(0,I)p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I})

  2. 生成模型:给定潜在变量 $ \mathbf{z} $,生成数据 $ \mathbf{x} $ 的条件分布:

    p(xz)=N(μ(z),σ2(z)I)p(\mathbf{x}|\mathbf{z}) = \mathcal{N}(\mu(\mathbf{z}), \sigma^2(\mathbf{z})\mathbf{I})

    其中,$ \mu(\mathbf{z}) $ 和 $ \sigma(\mathbf{z}) $ 由解码器网络参数化。

直接计算后验分布 $ p(\mathbf{z}|\mathbf{x}) $ 通常非常困难,因此VAE使用变分推断,通过引入一个可参数化的近似分布 $ q(\mathbf{z}|\mathbf{x}) $ 来逼近真实的后验分布。

为了训练模型,VAE最大化证据下界(Evidence Lower Bound, ELBO):

logp(x)Eq(zx)[logp(xz)]KL(q(zx)p(z))\log p(\mathbf{x}) \geq \mathbb{E}_{q(\mathbf{z}|\mathbf{x})} [\log p(\mathbf{x}|\mathbf{z})] - \text{KL}(q(\mathbf{z}|\mathbf{x}) || p(\mathbf{z}))

其中:

  • 重构项Eq(zx)[logp(xz)]\mathbb{E}_{q(\mathbf{z}|\mathbf{x})} [\log p(\mathbf{x}|\mathbf{z})]

    • 衡量模型重构数据的能力。
  • 正则化项KL(q(zx)p(z))\text{KL}(q(\mathbf{z}|\mathbf{x}) || p(\mathbf{z}))

    • 衡量近似后验分布与先验分布之间的差异,确保潜在空间的连续性和规则性。

VAE的损失函数由两个部分组成:

  1. 重构损失(Reconstruction Loss)

    Lrecon=Eq(zx)[logp(xz)]\mathcal{L}_{\text{recon}} = -\mathbb{E}_{q(\mathbf{z}|\mathbf{x})} [\log p(\mathbf{x}|\mathbf{z})]

    通常使用均方误差(MSE)或交叉熵作为具体形式。

  2. KL散度损失(KL Divergence Loss)

    LKL=KL(q(zx)p(z))\mathcal{L}_{\text{KL}} = \text{KL}(q(\mathbf{z}|\mathbf{x}) || p(\mathbf{z}))

    对于高斯分布,可以计算解析解:

    KL(N(μ,σ2)N(0,1))=12i=1d(1+log(σi2)μi2σi2)\text{KL}(\mathcal{N}(\mu, \sigma^2) || \mathcal{N}(0, 1)) = -\frac{1}{2} \sum_{i=1}^{d} \left(1 + \log(\sigma_i^2) - \mu_i^2 - \sigma_i^2\right)

最终的VAE损失函数为:

L=Lrecon+βLKL\mathcal{L} = \mathcal{L}_{\text{recon}} + \beta \mathcal{L}_{\text{KL}}

其中,$ \beta $ 是一个超参数,用于控制重构损失和KL散度损失之间的权衡。

训练阶段

  1. 编码

    • 输入数据 $ \mathbf{x} $ 通过编码器网络,输出潜在变量的参数 $ \mu(\mathbf{x}) $ 和 $ \log \sigma^2(\mathbf{x}) $。
  2. 重参数化技巧(Reparameterization Trick)

    • 为了实现反向传播,使用重参数化技巧从 $ q(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\mu(\mathbf{x}), \sigma^2(\mathbf{x})\mathbf{I}) $ 中采样:

      z=μ(x)+σ(x)ϵ,ϵN(0,I)\mathbf{z} = \mu(\mathbf{x}) + \sigma(\mathbf{x}) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I})

  3. 解码

    • 潜在变量 $ \mathbf{z} $ 通过解码器网络生成重构数据 $ \mathbf{x}’ $。
  4. 损失计算与优化

    • 计算重构损失和KL散度损失,优化整个网络以最小化总损失。

Diffusers 代码解析

flux代码中使用的vae是属于diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL类,因此来阅读对应的代码,具体的模型代码就不管了,主要关注这个使用的流程。

encode代码:调用_encode(x)得到z的均值和方差预测h = [mean, logvar],然后传入DiagonalGaussianDistribution类准备变量采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.

Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)

posterior = DiagonalGaussianDistribution(h)

if not return_dict:
return (posterior,)

return AutoencoderKLOutput(latent_dist=posterior)

_encode函数:接收image作为输入,然后输出刚才的[mean, logvar]。

1
2
3
4
5
6
7
8
9
10
11
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape

if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self._tiled_encode(x)

enc = self.encoder(x)
if self.quant_conv is not None:
enc = self.quant_conv(enc)

return enc

DiagonalGaussianDistribution类代码: 传入预测的[mean, logvar],可以调用sample采样得到重参数化的z。kl和nll应该是训练用的loss。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class DiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)

def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x

def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)

def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)

def mode(self) -> torch.Tensor:
return self.mean

decode:用上面采样得到的z重建image x。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.

Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.

"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample

if not return_dict:
return (decoded,)

return DecoderOutput(sample=decoded)

_decode:用采样得到的z重建image x的具体代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)

if self.post_quant_conv is not None:
z = self.post_quant_conv(z)

dec = self.decoder(z)

if not return_dict:
return (dec,)

return DecoderOutput(sample=dec)

DecoderOutput:就是一个存放输出的sample的数据类。

1
2
3
4
5
6
7
8
9
10
11
12
@dataclass
class DecoderOutput(BaseOutput):
r"""
Output of decoding method.

Args:
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
The decoded output sample from the last layer of the model.
"""

sample: torch.Tensor
commit_loss: Optional[torch.FloatTensor] = None

1. 数据类(@dataclass)的优势

1.1 提高代码的可读性和可维护性

结构化的数据表示

  • 使用 @dataclass 定义的类明确地展示了数据的结构和组成部分。每个字段都有明确的名称和类型,这使得代码更加自解释,易于理解。

    1
    2
    3
    4
    @dataclass
    class DecoderOutput:
    sample: torch.Tensor
    commit_loss: Optional[torch.FloatTensor] = None

    相比之下,使用元组或字典时,字段的意义可能不那么直观:

    1
    2
    3
    4
    5
    # 使用元组
    return sample, commit_loss

    # 使用字典
    return {"sample": sample, "commit_loss": commit_loss}

1.2 类型检查和静态分析支持

类型提示

  • 数据类允许为每个字段指定类型,这有助于静态类型检查工具(如 MyPy)在编译时捕捉类型错误,提升代码的可靠性。

    1
    2
    3
    4
    5
    6
    7
    8
    from typing import Optional
    import torch
    from dataclasses import dataclass

    @dataclass
    class DecoderOutput:
    sample: torch.Tensor
    commit_loss: Optional[torch.FloatTensor] = None

    使用元组或字典时,类型信息不够明确,可能导致类型错误更难以检测。

1.3 自动生成的特殊方法

自动生成方法

  • @dataclass 自动为类生成常用的特殊方法,如 __init__, __repr__, __eq__ 等。这减少了样板代码,提高了开发效率。

    1
    2
    3
    4
    @dataclass
    class DecoderOutput:
    sample: torch.Tensor
    commit_loss: Optional[torch.FloatTensor] = None

    上述代码自动生成了以下方法:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    def __init__(self, sample: torch.Tensor, commit_loss: Optional[torch.FloatTensor] = None):
    self.sample = sample
    self.commit_loss = commit_loss

    def __repr__(self):
    return f"DecoderOutput(sample={self.sample}, commit_loss={self.commit_loss})"

    def __eq__(self, other):
    if not isinstance(other, DecoderOutput):
    return False
    return self.sample == other.sample and self.commit_loss == other.commit_loss

    使用元组或字典,需要手动实现这些方法(如果需要),增加了代码复杂性。

1.4 更好的文档和代码提示

文档字符串和 IDE 支持

  • 数据类可以包含文档字符串(docstrings),为类和其字段提供详细说明。这对开发者来说非常有帮助,尤其是在使用 IDE 时,能够显示详细的提示信息。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    @dataclass
    class DecoderOutput:
    """
    Output of decoding method.

    Args:
    sample (torch.Tensor): The decoded output sample from the last layer of the model.
    commit_loss (Optional[torch.FloatTensor]): Additional loss for committing to the latent space.
    """
    sample: torch.Tensor
    commit_loss: Optional[torch.FloatTensor] = None

    使用元组或字典时,缺乏这种内置的文档支持,开发者需要依赖外部文档或代码注释。

1.5 错误减少和一致性提升

避免字段顺序错误

  • 使用数据类时,字段是通过名称访问的,减少了由于字段顺序错误导致的 bug。相比之下,使用元组时,必须按正确的顺序访问元素,容易出错。

    1
    2
    3
    output = DecoderOutput(sample=decoded_sample, commit_loss=loss)
    print(output.sample)
    print(output.commit_loss)

    使用元组:

    1
    2
    3
    output = (decoded_sample, loss)
    print(output[0]) # sample
    print(output[1]) # commit_loss

    如果不小心交换了顺序,可能会导致难以发现的错误。

1.6 易于扩展和修改

方便的字段添加和修改

  • 当需要向输出中添加新字段时,数据类只需在类定义中添加新字段,其他部分代码无需大幅修改。使用元组或字典时,可能需要修改多个代码位置来适应新字段。

    1
    2
    3
    4
    5
    @dataclass
    class DecoderOutput:
    sample: torch.Tensor
    commit_loss: Optional[torch.FloatTensor] = None
    additional_info: Optional[dict] = None

    使用元组时:

    1
    2
    3
    4
    5
    # 原先返回
    return sample, commit_loss

    # 修改后需要返回更多元素
    return sample, commit_loss, additional_info

    这会影响所有调用该函数的地方,增加了维护成本。

1.7 更好的错误检测

属性访问

  • 使用数据类时,通过属性访问字段,可以在编译时或静态分析时捕捉到不存在的属性访问错误。而使用元组或字典时,这类错误可能只在运行时才会被发现。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    # 数据类
    output = DecoderOutput(sample, commit_loss)
    print(output.sample) # 正确
    print(output.non_existent) # AttributeError

    # 元组
    output = (sample, commit_loss)
    print(output[0]) # 正确
    print(output[2]) # IndexError

    # 字典
    output = {"sample": sample, "commit_loss": commit_loss}
    print(output["sample"]) # 正确
    print(output["non_existent"]) # KeyError

1.8 简化调试和日志记录

清晰的输出

  • 数据类提供了清晰的 __repr__ 方法,使得在调试和日志记录时,输出信息更具可读性和可解释性。

    1
    2
    3
    4
    5
    6
    7
    8
    @dataclass
    class DecoderOutput:
    sample: torch.Tensor
    commit_loss: Optional[torch.FloatTensor] = None

    output = DecoderOutput(sample, commit_loss)
    print(output)
    # 输出: DecoderOutput(sample=tensor([...]), commit_loss=tensor([...]))

    使用元组或字典时,输出可能不够直观:

    1
    2
    3
    4
    5
    6
    7
    output = (sample, commit_loss)
    print(output)
    # 输出: (tensor([...]), tensor([...]))

    output = {"sample": sample, "commit_loss": commit_loss}
    print(output)
    # 输出: {'sample': tensor([...]), 'commit_loss': tensor([...])}

1.9 更好的集成和扩展性

与其他类和方法的集成

  • 数据类可以作为更复杂数据结构的一部分,便于与其他类和方法集成。例如,可以在数据类中嵌套其他数据类,形成更复杂的层次结构。

    1
    2
    3
    4
    @dataclass
    class ModelOutput:
    decoder_output: DecoderOutput
    additional_info: Optional[dict] = None

1.10 支持不可变性(可选)

不可变数据类

  • 通过设置 frozen=True,可以创建不可变的数据类,这在某些情况下有助于防止数据被意外修改,提升代码的安全性和稳定性。

    1
    2
    3
    4
    @dataclass(frozen=True)
    class DecoderOutput:
    sample: torch.Tensor
    commit_loss: Optional[torch.FloatTensor] = None

    这样,一旦实例化,就无法修改其字段:

    1
    2
    output = DecoderOutput(sample, commit_loss)
    output.sample = torch.randn(64, 784) # Raises FrozenInstanceError

vae代码阅读
https://deadsmither5.github.io/2024/12/28/vae代码阅读/
作者
zhaoxing
发布于
2024年12月28日
许可协议