Clip Loss

CLIP(Contrastive Language-Image Pretraining)使用了**对比学习(Contrastive Learning)**的方式来训练图像和文本的匹配关系,其核心思想是让正确的图像-文本对在特征空间中靠近,而错误的对则远离。本文将详细解析 CLIP 的损失函数的原理、推导,并提供 PyTorch 代码实现。


1. CLIP 的损失函数原理

CLIP 的损失函数基于 InfoNCE(信息噪声对比估计),也可以看作是 对比损失(Contrastive Loss)。其核心思想如下:

• 给定一个批次的 N 组(image, text)对,即 $$ {(I_1, T_1), (I_2, T_2), …, (I_N, T_N)} $$
• 通过 图像编码器(Vision Transformer 或 ResNet)将图像 I_i 映射到特征向量 v_i
• 通过 文本编码器(Transformer)将文本 T_i 映射到特征向量 t_i
• 通过归一化将这些向量映射到单位球面
• 计算所有 图像-文本对的相似度矩阵
• 使用 对比损失(Contrastive Loss) 使得正确的 (image, text) 对的相似度最大化,错误的最小化

1.1 计算相似度

设 v_i 是图像的嵌入(embedding),t_j 是文本的嵌入。CLIP 计算 图像-文本对之间的相似度

si,j=vitjvitjs_{i, j} = \frac{v_i \cdot t_j}{\|v_i\| \|t_j\|}

归一化的余弦相似度,表示图像 i 和文本 j 之间的匹配程度。

1.2 计算损失函数

CLIP 采用双向对比损失,即:

  1. 图像作为查询(image-to-text):希望每张图像 I_i 在所有文本 T_j 中,只与它的匹配文本 T_i 有最高的相似度。
  2. 文本作为查询(text-to-image):希望每个文本 T_i 在所有图像 I_j 中,只与它的匹配图像 I_i 有最高的相似度。

采用 交叉熵损失(Cross-Entropy Loss),分别对行和列计算损失:

(1) Image-to-Text(I→T)损失

对于每个图像 I_i,正确的文本是 T_i,计算 Softmax 归一化概率:

LIT=1Ni=1Nlogexp(si,i/τ)j=1Nexp(si,j/τ)\mathcal{L}_{I \rightarrow T} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(s_{i, i} / \tau)}{\sum_{j=1}^{N} \exp(s_{i, j} / \tau)}

其中:
• $$ \tau $$ 是温度参数(temperature parameter),用于调整分布的陡峭程度。

(2) Text-to-Image(T→I)损失

类似地,对于每个文本 T_j,正确的图像是 I_j:

LTI=1Nj=1Nlogexp(sj,j/τ)i=1Nexp(si,j/τ)\mathcal{L}_{T \rightarrow I} = -\frac{1}{N} \sum_{j=1}^{N} \log \frac{\exp(s_{j, j} / \tau)}{\sum_{i=1}^{N} \exp(s_{i, j} / \tau)}

(3) 总损失

最终损失是两个方向的平均:

L=12(LIT+LTI)\mathcal{L} = \frac{1}{2} (\mathcal{L}_{I \rightarrow T} + \mathcal{L}_{T \rightarrow I})


2. PyTorch 代码实现 CLIP 损失函数

下面是 CLIP 的 对比损失(Contrastive Loss) 在 PyTorch 中的实现。

2.1 CLIP 主要步骤

编码图像和文本(使用 Transformer 和 ViT)
归一化特征
计算余弦相似度
计算 InfoNCE 损失(交叉熵)

import torch  
import torch.nn.functional as F  

class CLIPLoss(torch.nn.Module):  
    def __init__(self, temperature=0.07):  
        super(CLIPLoss, self).__init__()  
        self.temperature = temperature  

    def forward(self, image_features, text_features):  
        """  
        计算 CLIP 的对比损失  
        
        参数:  
        - image_features: 形状为 (N, D) 的图像特征  
        - text_features: 形状为 (N, D) 的文本特征  

        返回:  
        - CLIP 损失  
        """  

        # 归一化特征(L2 归一化到单位球面)  
        image_features = F.normalize(image_features, p=2, dim=-1)  # (N, D)  
        text_features = F.normalize(text_features, p=2, dim=-1)  # (N, D)  

        # 计算相似度矩阵(余弦相似度)  
        logits = (image_features @ text_features.T) / self.temperature  # (N, N)  

        # 目标标签:对角线上的才是匹配的  
        labels = torch.arange(logits.shape[0]).to(logits.device)  # (N,)  

        # 计算交叉熵损失(分别计算 I->T 和 T->I)  
        loss_i2t = F.cross_entropy(logits, labels)  # 图像->文本  
        loss_t2i = F.cross_entropy(logits.T, labels)  # 文本->图像  

        # 取平均  
        loss = (loss_i2t + loss_t2i) / 2  

        return loss  

3. 代码解析

3.1 归一化

image_features = F.normalize(image_features, p=2, dim=-1)  
text_features = F.normalize(text_features, p=2, dim=-1)  

将图像和文本特征归一化到单位球面,确保计算的相似度是余弦相似度

3.2 计算相似度

logits = (image_features @ text_features.T) / self.temperature  

使用矩阵乘法计算图像-文本相似度,除以温度参数 $$ \tau $$。

3.3 计算损失

labels = torch.arange(logits.shape[0]).to(logits.device)  
loss_i2t = F.cross_entropy(logits, labels)  
loss_t2i = F.cross_entropy(logits.T, labels)  
loss = (loss_i2t + loss_t2i) / 2  

labels = torch.arange(N): 真实匹配的文本索引
cross_entropy(logits, labels): 计算交叉熵损失
• 计算 image-to-texttext-to-image,并取均值。


4. 结论

CLIP 的损失函数是对比损失的一种实现方式,利用对称交叉熵损失来对图像和文本进行匹配训练。PyTorch 代码实现主要包括:

  1. 计算特征的归一化
  2. 计算相似度矩阵
  3. 计算双向交叉熵损失
  4. 取均值得到最终损失

logits 矩阵 S ** 是一个方阵**,其形状为(N, N),对角线元素 S_{i,i} 代表匹配的图像-文本对的相似度。既然 S 是方阵,那么它的转置 S^T 也是方阵,对角线元素不变。那么,为什么 CLIP 还要分成 两个损失(image-to-text 和 text-to-image)呢?


1. 为什么要分 Image-to-Text 和 Text-to-Image 损失?

尽管 logits 矩阵是方阵,但它的行和列的含义不同:
行视角(image-to-text, I→T):
每一行 i 代表图像 I_i 与所有文本 T_j 的相似度。我们希望图像 I_i 和正确的文本 T_i 的相似度最大,其他文本的相似度最小。

损失公式:

LIT=1Ni=1Nlogexp(Si,i/τ)j=1Nexp(Si,j/τ)\mathcal{L}_{I \rightarrow T} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(S_{i, i} / \tau)}{\sum_{j=1}^{N} \exp(S_{i, j} / \tau)}

这里,我们在每一行做 softmax,把每张图像看成 query,所有文本作为候选

列视角(text-to-image, T→I):
每一列 j 代表文本 T_j 与所有图像 I_i 的相似度。我们希望文本 T_j 和正确的图像 I_j 的相似度最大,其他图像的相似度最小。

损失公式:

LTI=1Nj=1Nlogexp(Sj,j/τ)i=1Nexp(Si,j/τ)\mathcal{L}_{T \rightarrow I} = -\frac{1}{N} \sum_{j=1}^{N} \log \frac{\exp(S_{j, j} / \tau)}{\sum_{i=1}^{N} \exp(S_{i, j} / \tau)}

这里,我们在每一列做 softmax,把每个文本看成 query,所有图像作为候选

核心区别:Softmax 计算方式不同

$$ \mathcal{L}_{I \rightarrow T} $$ 计算的是每行的 softmax(图像为 query)
$$ \mathcal{L}_{T \rightarrow I} $$ 计算的是每列的 softmax(文本为 query)
尽管对角线元素数值相同,但 softmax 归一化时,分母不同,导致两个方向的 loss 并不等价。


2. 直观理解

假设有 3 组 (image, text) 对:

S=[s1,1s1,2s1,3s2,1s2,2s2,3s3,1s3,2s3,3] S = \begin{bmatrix} s_{1,1} & s_{1,2} & s_{1,3} \\ s_{2,1} & s_{2,2} & s_{2,3} \\ s_{3,1} & s_{3,2} & s_{3,3} \end{bmatrix}

行视角(I→T):
• 以 第一行为例,我们计算 图像 I_1 与所有文本 T_j 的 softmax
$$ P(T_1 | I_1) = \frac{\exp(s_{1,1} / \tau)}{\exp(s_{1,1} / \tau) + \exp(s_{1,2} / \tau) + \exp(s_{1,3} / \tau)} $$
• 这样每一行的 softmax 归一化,使得图像 embedding 学到的是更好地匹配文本

列视角(T→I):
• 以 第一列为例,我们计算 文本 T_1 与所有图像 I_i 的 softmax
$$ P(I_1 | T_1) = \frac{\exp(s_{1,1} / \tau)}{\exp(s_{1,1} / \tau) + \exp(s_{2,1} / \tau) + \exp(s_{3,1} / \tau)} $$
• 这样每一列的 softmax 归一化,使得文本 embedding 学到的是更好地匹配图像

结论:即便 S 是对称的,行 softmax 和列 softmax 计算方式不同,导致优化目标不同,必须分别计算。


3. PyTorch 代码解析

loss_i2t = F.cross_entropy(logits, labels)  # 图像->文本  
loss_t2i = F.cross_entropy(logits.T, labels)  # 文本->图像  

loss_i2t: 计算行 softmax,图像作为 query
loss_t2i: 计算列 softmax,文本作为 query
两个方向都优化,确保图像和文本 embedding 都足够强


4. 总结

为什么 CLIP 需要分两个 loss?

  1. logits 矩阵 S 是方阵,但 softmax 计算方式不同
    I→T 计算每行的 softmax(图像为 query)
    T→I 计算每列的 softmax(文本为 query)
    • 这导致两个方向的 loss 不等价。

  2. 优化目标不同
    loss_i2t 让图像更容易找到正确的文本。
    loss_t2i 让文本更容易找到正确的图像。

  3. 双向优化更鲁棒
    • 如果只优化 I→T,那么可能文本的表示能力较差,反之亦然。
    双向优化保证模型学习到更均衡的表征,提高检索和匹配能力。

因此,CLIP 采用了双向损失(image-to-text + text-to-image),确保图像和文本的嵌入在语义空间中对齐得更好。


Clip Loss
https://deadsmither5.github.io/2025/03/09/Loss/
作者
zhaoxing
发布于
2025年3月9日
许可协议