Building Your First Few-Shot Learning Dataset from Scratch: miniImageNet with PyTorch and ImageFolder
When you set out to reproduce a classic few-shot learning paper, the biggest headache is often not the model implementation but preparing that "standard" dataset. This article walks through building the miniImageNet dataset from scratch and implementing a data loading pipeline that supports episode sampling.
1. Understanding Where miniImageNet Comes From
miniImageNet was first introduced by Vinyals et al. in the 2016 paper "Matching Networks for One Shot Learning" as a subset of ImageNet designed specifically for few-shot learning research. It contains 100 classes with 600 images per class, for a total of 60,000 images at 84x84 resolution.
Why choose miniImageNet?
- Manageable scale: compared with the full ImageNet (1,000 classes, 1.2 million images), miniImageNet is far better suited to rapid experimentation
- Standard splits: widely adopted by the community, making results easy to compare
- Challenging: it retains ImageNet's diversity, making it a good test of a model's generalization ability
Note: different papers use different miniImageNet splits. This article follows the split proposed by Ravi et al. in "Optimization as a Model for Few-Shot Learning": 64 training classes, 16 validation classes, and 20 test classes.
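Once the split CSVs are in place, the 64/16/20 class counts are easy to sanity-check. The sketch below assumes Ravi-style CSVs with `filename` and `label` columns, stored under a hypothetical `./splits/` directory:

```python
import pandas as pd

# Expected class counts for the Ravi et al. split: 64 train / 16 val / 20 test
EXPECTED_CLASSES = {'train': 64, 'val': 16, 'test': 20}

def check_split(df, split):
    """Verify that a split's CSV contains the expected number of classes."""
    n = df['label'].nunique()
    assert n == EXPECTED_CLASSES[split], (
        f"{split}: got {n} classes, expected {EXPECTED_CLASSES[split]}")
    return n

# Usage (assuming the CSVs live under ./splits/):
# for split in EXPECTED_CLASSES:
#     df = pd.read_csv(f'./splits/{split}.csv')  # columns: filename, label
#     print(split, check_split(df, split), 'classes,', len(df), 'images')
```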
2. From Raw ImageNet to miniImageNet
2.1 Obtaining the Raw ImageNet Data
Since ImageNet is no longer publicly available for direct download from the official site, we can obtain the ILSVRC2012 dataset via Academic Torrents:
```bash
# Training set (138GB)
wget http://academictorrents.com/download/a306397ccf9c2ead27155983c254227c0fd938e2.torrent
# Validation set (6.3GB)
wget http://academictorrents.com/download/5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5.torrent
# (these fetch the .torrent metadata; open them with a torrent client to download)
```

After downloading, verify file integrity:
```bash
md5sum ILSVRC2012_img_train.tar
# Expected: 1d675b47d978889d74fa0da5fadfb00e
md5sum ILSVRC2012_img_val.tar
# Expected: 29b22e2961454d5413ddabcf34fc5622
```

2.2 Extraction and Preprocessing
The ImageNet training set is stored as one tar archive per class, so we first extract the outer tar, then extract each inner tar in turn:
```python
import os
import tarfile

def extract_tar(tar_path, output_dir):
    with tarfile.open(tar_path) as tar:
        tar.extractall(path=output_dir)
    # Extract the per-class inner tar files into their own directories
    for class_tar in os.listdir(output_dir):
        if class_tar.endswith('.tar'):
            class_name = class_tar.split('.')[0]
            class_dir = os.path.join(output_dir, class_name)
            os.makedirs(class_dir, exist_ok=True)
            with tarfile.open(os.path.join(output_dir, class_tar)) as tar:
                tar.extractall(path=class_dir)
            os.remove(os.path.join(output_dir, class_tar))
```

2.3 Building the miniImageNet Subset
We will use the split files provided by Ravi et al. to select 100 of the 1,000 classes and divide them into train/val/test sets at a 64/16/20 ratio:
```python
import os
import pandas as pd
from shutil import copyfile
import cv2

def build_mini_imagenet(original_dir, output_dir, split_csv_dir, size=84):
    splits = ['train', 'val', 'test']
    for split in splits:
        split_dir = os.path.join(output_dir, split)
        os.makedirs(split_dir, exist_ok=True)
        df = pd.read_csv(os.path.join(split_csv_dir, f'{split}.csv'))
        for _, row in df.iterrows():
            # find_original_path: helper that locates the source image
            # inside the extracted ImageNet directory (not shown here)
            src_path = find_original_path(original_dir, row['filename'])
            dst_dir = os.path.join(split_dir, row['label'])
            os.makedirs(dst_dir, exist_ok=True)
            if size > 0:
                img = cv2.imread(src_path)
                img = cv2.resize(img, (size, size))
                cv2.imwrite(os.path.join(dst_dir, row['filename']), img)
            else:
                copyfile(src_path, os.path.join(dst_dir, row['filename']))
```

3. Building a Dataset Class with Episode Sampling
3.1 A Custom ImageFolder Dataset
PyTorch's ImageFolder assigns labels in folder order by default; we customize it so labels stay consistent with the original split:
```python
import os
import pandas as pd
from torchvision.datasets import ImageFolder

class MiniImageNetDataset(ImageFolder):
    def __init__(self, root, phase='train', transform=None):
        self.phase = phase
        self.class_to_idx = self._load_class_mapping()
        super().__init__(root=os.path.join(root, phase), transform=transform)

    def _load_class_mapping(self):
        # Build the label mapping from the split CSV so class indices
        # follow the split file's order rather than folder order
        df = pd.read_csv(f'./splits/{self.phase}.csv')
        unique_labels = df['label'].unique()
        return {label: idx for idx, label in enumerate(unique_labels)}

    def find_classes(self, directory):
        return list(self.class_to_idx.keys()), self.class_to_idx
```

3.2 Implementing the Episode Sampler
Few-shot learning is usually trained in episodes: each episode contains N classes with K samples per class:
```python
import numpy as np
from torch.utils.data.sampler import Sampler

class EpisodeSampler(Sampler):
    def __init__(self, labels, n_way, k_shot, n_query, n_episodes):
        self.labels = labels
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        self.n_episodes = n_episodes
        # Group sample indices by class
        self.class_indices = {}
        for idx, label in enumerate(labels):
            if label not in self.class_indices:
                self.class_indices[label] = []
            self.class_indices[label].append(idx)

    def __iter__(self):
        for _ in range(self.n_episodes):
            # Randomly pick n_way classes
            selected_classes = np.random.choice(
                list(self.class_indices.keys()), self.n_way, replace=False)
            batch = []
            for cls in selected_classes:
                # Sample k_shot + n_query examples from the current class
                indices = np.random.choice(
                    self.class_indices[cls], self.k_shot + self.n_query,
                    replace=False)
                batch.extend(indices)
            yield batch

    def __len__(self):
        return self.n_episodes
```

3.3 Building the Data Loading Pipeline
Now combine the custom dataset and sampler:
```python
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

dataset = MiniImageNetDataset(root='./miniImageNet', phase='train',
                              transform=transform)
sampler = EpisodeSampler(dataset.targets, n_way=5, k_shot=5,
                         n_query=15, n_episodes=100)
dataloader = DataLoader(dataset, batch_sampler=sampler)
```

4. Advanced Tips and Optimizations
4.1 Data Augmentation Strategies
Few-shot learning particularly needs strong data augmentation to prevent overfitting:
```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(84),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize(92),
    transforms.CenterCrop(84),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
```

4.2 Caching for Speed
For frequently accessed samples, caching them in memory can speed things up:
```python
from torch.utils.data import Dataset

class CachedDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.cache = {}

    def __getitem__(self, index):
        # Note: this caches the transformed sample, so random
        # augmentations are frozen after the first access; use it
        # with deterministic (e.g. test-time) transforms
        if index not in self.cache:
            self.cache[index] = self.dataset[index]
        return self.cache[index]

    def __len__(self):
        return len(self.dataset)
```

4.3 Distributed Training Support
Modify the sampler to support distributed training:
```python
import numpy as np
import torch.distributed as dist

class DistributedEpisodeSampler(EpisodeSampler):
    def __init__(self, labels, n_way, k_shot, n_query, n_episodes,
                 num_replicas=None, rank=None):
        if num_replicas is None:
            num_replicas = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        # Every process generates the full episode list, then keeps
        # only its own share, so episodes never overlap across ranks
        super().__init__(labels, n_way, k_shot, n_query, n_episodes)
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        # Seed with the epoch so all processes draw identical episodes,
        # then yield only the episodes assigned to this rank
        np.random.seed(self.epoch)
        episodes = list(super().__iter__())
        for i in range(self.rank, len(episodes), self.num_replicas):
            yield episodes[i]

    def __len__(self):
        # Number of episodes this rank actually yields
        return len(range(self.rank, self.n_episodes, self.num_replicas))
```
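Putting the pieces together: each batch produced by the episode sampler is a flat tensor of n_way * (k_shot + n_query) images, grouped class by class, so a small helper (a sketch, independent of any specific model; `split_episode` is a name introduced here) can split it into support and query sets:

```python
import torch

def split_episode(images, labels, n_way, k_shot, n_query):
    """Split a flat episode batch (grouped class-by-class, as produced
    by EpisodeSampler) into support and query tensors."""
    # images: [n_way * (k_shot + n_query), C, H, W]
    images = images.view(n_way, k_shot + n_query, *images.shape[1:])
    labels = labels.view(n_way, k_shot + n_query)
    support_x, query_x = images[:, :k_shot], images[:, k_shot:]
    support_y, query_y = labels[:, :k_shot], labels[:, k_shot:]
    return support_x, support_y, query_x, query_y

# Usage with the dataloader from section 3.3:
# for images, labels in dataloader:
#     sx, sy, qx, qy = split_episode(images, labels,
#                                    n_way=5, k_shot=5, n_query=15)
#     # sx: [5, 5, 3, 84, 84], qx: [5, 15, 3, 84, 84]
```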