Clip Loss
CLIP(Contrastive Language-Image Pretraining)使用了**对比学习(Contrastive Learning)**的方式来训练图像和文本的匹配关系,其核心思想是让正确的图像-文本对在特征空间中靠近,而错误的对则远离。本文将详细解析 CLIP 的损失函数的原理、推导,并提供 PyTorch 代码实现。
1. CLIP 的损失函数原理
CLIP 的损失函数基于 InfoNCE(信息噪声对比估计),也可以看作是 对比损失(Contrastive Loss)。其核心思想如下:
• 给定一个批次的 N 组(image, text)对,即 $$ {(I_1, T_1), (I_2, T_2), …, (I_N, T_N)} $$
• 通过 图像编码器(Vision Transformer 或 ResNet)将图像 I_i 映射到特征向量 v_i
• 通过 文本编码器(Transformer)将文本 T_i 映射到特征向量 t_i
• 通过归一化将这些向量映射到单位球面
• 计算所有 图像-文本对的相似度矩阵
• 使用 对比损失(Contrastive Loss) 使得正确的 (image, text) 对的相似度最大化,错误的最小化
1.1 计算相似度
设 v_i 是图像的嵌入(embedding),t_j 是文本的嵌入。CLIP 计算 图像-文本对之间的相似度 :
即归一化的余弦相似度,表示图像 i 和文本 j 之间的匹配程度。
1.2 计算损失函数
CLIP 采用双向对比损失,即:
- 图像作为查询(image-to-text):希望每张图像 I_i 在所有文本 T_j 中,只与它的匹配文本 T_i 有最高的相似度。
- 文本作为查询(text-to-image):希望每个文本 T_i 在所有图像 I_j 中,只与它的匹配图像 I_i 有最高的相似度。
采用 交叉熵损失(Cross-Entropy Loss),分别对行和列计算损失:
(1) Image-to-Text(I→T)损失
对于每个图像 I_i,正确的文本是 T_i,计算 Softmax 归一化概率:
其中:
• $$ \tau $$ 是温度参数(temperature parameter),用于调整分布的陡峭程度。
(2) Text-to-Image(T→I)损失
类似地,对于每个文本 T_j,正确的图像是 I_j:
(3) 总损失
最终损失是两个方向的平均:
2. PyTorch 代码实现 CLIP 损失函数
下面是 CLIP 的 对比损失(Contrastive Loss) 在 PyTorch 中的实现。
2.1 CLIP 主要步骤
• 编码图像和文本(使用 Transformer 和 ViT)
• 归一化特征
• 计算余弦相似度
• 计算 InfoNCE 损失(交叉熵)
import torch
import torch.nn.functional as F
class CLIPLoss(torch.nn.Module):
def __init__(self, temperature=0.07):
super(CLIPLoss, self).__init__()
self.temperature = temperature
def forward(self, image_features, text_features):
"""
计算 CLIP 的对比损失
参数:
- image_features: 形状为 (N, D) 的图像特征
- text_features: 形状为 (N, D) 的文本特征
返回:
- CLIP 损失
"""
# 归一化特征(L2 归一化到单位球面)
image_features = F.normalize(image_features, p=2, dim=-1) # (N, D)
text_features = F.normalize(text_features, p=2, dim=-1) # (N, D)
# 计算相似度矩阵(余弦相似度)
logits = (image_features @ text_features.T) / self.temperature # (N, N)
# 目标标签:对角线上的才是匹配的
labels = torch.arange(logits.shape[0]).to(logits.device) # (N,)
# 计算交叉熵损失(分别计算 I->T 和 T->I)
loss_i2t = F.cross_entropy(logits, labels) # 图像->文本
loss_t2i = F.cross_entropy(logits.T, labels) # 文本->图像
# 取平均
loss = (loss_i2t + loss_t2i) / 2
return loss
3. 代码解析
3.1 归一化
image_features = F.normalize(image_features, p=2, dim=-1)
text_features = F.normalize(text_features, p=2, dim=-1)
将图像和文本特征归一化到单位球面,确保计算的相似度是余弦相似度。
3.2 计算相似度
logits = (image_features @ text_features.T) / self.temperature
使用矩阵乘法计算图像-文本相似度,除以温度参数 $$ \tau $$。
3.3 计算损失
labels = torch.arange(logits.shape[0]).to(logits.device)
loss_i2t = F.cross_entropy(logits, labels)
loss_t2i = F.cross_entropy(logits.T, labels)
loss = (loss_i2t + loss_t2i) / 2
• labels = torch.arange(N)
: 真实匹配的文本索引
• cross_entropy(logits, labels)
: 计算交叉熵损失
• 计算 image-to-text 和 text-to-image,并取均值。
4. 结论
CLIP 的损失函数是对比损失的一种实现方式,利用对称交叉熵损失来对图像和文本进行匹配训练。PyTorch 代码实现主要包括:
- 计算特征的归一化
- 计算相似度矩阵
- 计算双向交叉熵损失
- 取均值得到最终损失
logits 矩阵 S ** 是一个方阵**,其形状为(N, N),对角线元素 S_{i,i} 代表匹配的图像-文本对的相似度。既然 S 是方阵,那么它的转置 S^T 也是方阵,对角线元素不变。那么,为什么 CLIP 还要分成 两个损失(image-to-text 和 text-to-image)呢?
1. 为什么要分 Image-to-Text 和 Text-to-Image 损失?
尽管 logits 矩阵是方阵,但它的行和列的含义不同:
• 行视角(image-to-text, I→T):
每一行 i 代表图像 I_i 与所有文本 T_j 的相似度。我们希望图像 I_i 和正确的文本 T_i 的相似度最大,其他文本的相似度最小。
损失公式:
这里,我们在每一行做 softmax,把每张图像看成 query,所有文本作为候选。
• 列视角(text-to-image, T→I):
每一列 j 代表文本 T_j 与所有图像 I_i 的相似度。我们希望文本 T_j 和正确的图像 I_j 的相似度最大,其他图像的相似度最小。
损失公式:
这里,我们在每一列做 softmax,把每个文本看成 query,所有图像作为候选。
核心区别:Softmax 计算方式不同
• $$ \mathcal{L}_{I \rightarrow T} $$ 计算的是每行的 softmax(图像为 query)
• $$ \mathcal{L}_{T \rightarrow I} $$ 计算的是每列的 softmax(文本为 query)
• 尽管对角线元素数值相同,但 softmax 归一化时,分母不同,导致两个方向的 loss 并不等价。
2. 直观理解
假设有 3 组 (image, text) 对:
• 行视角(I→T):
• 以 第一行为例,我们计算 图像 I_1 与所有文本 T_j 的 softmax:
$$ P(T_1 | I_1) = \frac{\exp(s_{1,1} / \tau)}{\exp(s_{1,1} / \tau) + \exp(s_{1,2} / \tau) + \exp(s_{1,3} / \tau)} $$
• 这样每一行的 softmax 归一化,使得图像 embedding 学到的是更好地匹配文本。
• 列视角(T→I):
• 以 第一列为例,我们计算 文本 T_1 与所有图像 I_i 的 softmax:
$$ P(I_1 | T_1) = \frac{\exp(s_{1,1} / \tau)}{\exp(s_{1,1} / \tau) + \exp(s_{2,1} / \tau) + \exp(s_{3,1} / \tau)} $$
• 这样每一列的 softmax 归一化,使得文本 embedding 学到的是更好地匹配图像。
结论:即便 S 是对称的,行 softmax 和列 softmax 计算方式不同,导致优化目标不同,必须分别计算。
3. PyTorch 代码解析
loss_i2t = F.cross_entropy(logits, labels) # 图像->文本
loss_t2i = F.cross_entropy(logits.T, labels) # 文本->图像
• loss_i2t
: 计算行 softmax,图像作为 query
• loss_t2i
: 计算列 softmax,文本作为 query
• 两个方向都优化,确保图像和文本 embedding 都足够强
4. 总结
为什么 CLIP 需要分两个 loss?
-
logits 矩阵 S 是方阵,但 softmax 计算方式不同
• I→T 计算每行的 softmax(图像为 query)
• T→I 计算每列的 softmax(文本为 query)
• 这导致两个方向的 loss 不等价。 -
优化目标不同
•loss_i2t
让图像更容易找到正确的文本。
•loss_t2i
让文本更容易找到正确的图像。 -
双向优化更鲁棒
• 如果只优化 I→T,那么可能文本的表示能力较差,反之亦然。
• 双向优化保证模型学习到更均衡的表征,提高检索和匹配能力。
因此,CLIP 采用了双向损失(image-to-text + text-to-image),确保图像和文本的嵌入在语义空间中对齐得更好。