Day 13: 图像分类与 Vision Transformer (ViT)
摘要:2020年,一张名为 “An Image is Worth 16x16 Words” 的论文让计算机视觉圈炸开了锅。Vision Transformer (ViT) 证明了不依赖卷积,纯 Transformer 也能在图像分类上取得 SOTA 效果。本文将深入拆解 ViT 及其进化版 Swin Transformer,并介绍 MixUp、CutMix 等现代数据增强技术。
1. Vision Transformer (ViT)
在 ViT 之前,CV 是 CNN 的天下(ResNet, EfficientNet)。ViT 的核心思想是:把图片当成一种特殊的“语言”序列。
1.1 “An Image is Worth 16x16 Words”
Transformer 需要序列作为输入。ViT 的做法简单粗暴:
- 切块 (Patch Partition):把一张224 × 224 224 \times 224224×224的图片,切成16 × 16 16 \times 1616×16大小的小方块。
- 总共有( 224 / 16 ) × ( 224 / 16 ) = 14 × 14 = 196 (224/16) \times (224/16) = 14 \times 14 = 196(224/16)×(224/16)=14×14=196个方块(Patch)。
- 拉平 (Flatten):每个方块展平成一个向量。
- 线性映射 (Linear Projection):用一个全连接层把这些向量映射到 Transformer 的维度(如 768)。
- 位置编码 (Positional Embedding):给这 196 个向量加上位置信息(0, 1, 2…)。
- Class Token:在序列最前面加一个特殊的可学习向量
[CLS]。- Transformer 输出时,只取这个
[CLS]对应的输出向量去做分类。
- Transformer 输出时,只取这个
1.2 ViT vs CNN
- 归纳偏置 (Inductive Bias):
- CNN:天生假设“局部性”(像素和周围相关)和“平移等变性”(猫在左上和右下是一样的)。这也是 CNN 训练快的原因。
- ViT:没有这些假设。它一开始什么都不知道,必须通过海量数据(如 JFT-300M)去自己学习像素之间的关系。
- 结论:数据量小用 CNN,数据量超级大用 ViT。
2. Swin Transformer
ViT 虽然强,但有两个缺陷:
- 计算量大:Global Attention 的复杂度是序列长度的平方O ( N 2 ) O(N^2)O(N2)。图片分辨率一高(如 1000x1000),Token 数爆炸,显存直接撑爆。
- 缺乏多尺度:ViT 直上直下,没有像 CNN 那样“先看细节,再看整体”的层次感。
Swin Transformer(Hierarchical Vision Transformer) 借鉴了 CNN 的思想:
2.1 窗口注意力 (Window Attention)
- ViT 的痛点:像开全员大会。1000 个人(像素)每个人都要和另外 999 个人握手。效率极低。
- Swin 的策略:像分小组讨论。把图划分成很多个 7x7 的小窗口。Attention 只在窗口内部算。
- 效果:复杂度从O ( N 2 ) O(N^2)O(N2)降到了O ( N ) O(N)O(N)。
2.2 移动窗口 (Shifted Window)
- 问题:小组之间没有交流,信息闭塞。
- 策略(换座位):
- Layer 1:按正常分组。
- Layer 2:移动分组界线(Shift)。比如把原来 A 组的一半人和 B 组的一半人拼成新组。
- 效果:多搞几轮,每个人都能间接地和所有人交流了。既保留了局部的高效,又实现了全局信息的流通。
3. 现代数据增强 (Data Augmentation)
ViT 这种“数据饥渴”的模型,非常依赖强力的数据增强来防止过拟合。
3.1 MixUp
简单粗暴地把两张图按比例混合。
- Input:λ × Cat + ( 1 − λ ) × Dog \lambda \times \text{Cat} + (1-\lambda) \times \text{Dog}λ×Cat+(1−λ)×Dog
- Label:λ × [ 1 , 0 ] + ( 1 − λ ) × [ 0 , 1 ] \lambda \times [1, 0] + (1-\lambda) \times [0, 1]λ×[1,0]+(1−λ)×[0,1]
- 为什么这么做?
- 现实中虽然没有“半猫半狗”,但这种训练能平滑决策边界。
- 它强迫模型理解“特征变了一点点,结果也应该只变一点点”,而不是非黑即白。这让模型更稳健,不易过拟合。
- 必要性:ViT 缺乏归纳偏置,极易过拟合,所以比 CNN 更依赖这种强力数据增强。
3.2 CutMix
剪切粘贴。
- 在图片 A 上随机挖个框,把图片 B 的对应区域填进去。
- 标签也按面积比例混合。
- 比 MixUp 更自然,因为像素值没有失真,保留了局部纹理。
4. 知识蒸馏 (Knowledge Distillation)
在分类任务中,我们常想把大模型(Teacher)的能力传给小模型(Student)。
- Hard Label: 真实标签(One-hot)。
- Soft Label: Teacher 输出的概率分布(比如:猫0.8,狗0.15,车0.05)。
- 原理:Soft Label 包含了暗知识(Dark Knowledge)。比如 Teacher 认为“这张图虽然是猫,但长得有点像狗”。这种信息能帮 Student 更好地泛化。
5. 代码实践:MixUp 实现
importtorchimportnumpyasnpdefmixup_data(x,y,alpha=1.0):'''Returns mixed inputs, pairs of targets, and lambda'''ifalpha>0:# Beta分布采样 lambdalam=np.random.beta(alpha,alpha)else:lam=1batch_size=x.size(0)# 随机生成乱序索引index=torch.randperm(batch_size).to(x.device)# 混合 Inputmixed_x=lam*x+(1-lam)*x[index]# 返回: 混合后的图, 原始标签,乱序标签, 混合比例returnmixed_x,y,y[index],lamdefmixup_criterion(criterion,pred,y_a,y_b,lam):'''计算混合后的 Loss'''returnlam*criterion(pred,y_a)+(1-lam)*criterion(pred,y_b)# 训练循环中使用# inputs, targets = data# inputs, targets_a, targets_b, lam = mixup_data(inputs, targets)# outputs = model(inputs)# loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)# loss.backward()6. 总结
- ViT打破了 CV 和 NLP 的壁垒,证明了 Transformer 的通用性。
- Swin Transformer引入了 CNN 的层次化和局部性设计,成为了目前 CV 任务的通用骨干。
- MixUp/CutMix是训练高性能 ViT 的必备良药。
掌握了这些,你就不再局限于 ResNet,而是进入了 CV 的 Transformer 时代。
参考资料
- An Image is Worth 16x16 Words (ViT Paper)
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows