news 2026/4/22 2:13:59

从零构建你的第一个少样本学习数据集:用PyTorch和ImageFolder玩转miniImageNet

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
从零构建你的第一个少样本学习数据集:用PyTorch和ImageFolder玩转miniImageNet

从零构建你的第一个少样本学习数据集:用PyTorch和ImageFolder玩转miniImageNet

当你想复现一篇经典的少样本学习论文时,最令人头疼的往往不是模型实现,而是如何准备那个"标准"的数据集。本文将带你从零开始,一步步构建miniImageNet数据集,并实现支持episode采样的数据加载流程。

1. 理解miniImageNet的前世今生

miniImageNet最初由Vinyals等人在2016年的《Matching Networks for One Shot Learning》论文中提出,作为ImageNet的一个子集,专门用于少样本学习研究。它包含了100个类别,每个类别有600张图片,总共有6万张84x84大小的图像。

为什么选择miniImageNet?

  • 规模适中:相比完整的ImageNet(1000类,120万张),miniImageNet更适合快速实验
  • 标准划分:已被社区广泛接受,便于结果对比
  • 挑战性:保留了ImageNet的多样性,适合测试模型的泛化能力

注意:不同论文中使用的miniImageNet划分可能不同,本文采用Ravi等人在《Optimization as a model for few-shot learning》中的划分方式:64类训练,16类验证,20类测试。

2. 从原始ImageNet到miniImageNet

2.1 获取原始ImageNet数据

由于ImageNet官方已不再开放公开下载,我们可以通过学术种子获取ILSVRC2012数据集:

# 训练集 (138GB) wget http://academictorrents.com/download/a306397ccf9c2ead27155983c254227c0fd938e2.torrent # 验证集 (6.3GB) wget http://academictorrents.com/download/5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5.torrent

下载完成后,验证文件完整性:

md5sum ILSVRC2012_img_train.tar # 应输出: 1d675b47d978889d74fa0da5fadfb00e md5sum ILSVRC2012_img_val.tar # 应输出: 29b22e2961454d5413ddabcf34fc5622

2.2 解压与预处理

ImageNet的训练集是按类别分tar包存储的,我们需要先解压外层tar,再逐个解压内部tar:

import os import tarfile def extract_tar(tar_path, output_dir): with tarfile.open(tar_path) as tar: tar.extractall(path=output_dir) # 解压内部tar文件 for class_tar in os.listdir(output_dir): if class_tar.endswith('.tar'): class_name = class_tar.split('.')[0] class_dir = os.path.join(output_dir, class_name) os.makedirs(class_dir, exist_ok=True) with tarfile.open(os.path.join(output_dir, class_tar)) as tar: tar.extractall(path=class_dir) os.remove(os.path.join(output_dir, class_tar))

2.3 构建miniImageNet子集

我们将使用Ravi提供的划分方案,从1000类中选出100类,并按64/16/20的比例划分训练/验证/测试集:

import pandas as pd from shutil import copyfile import cv2 def build_mini_imagenet(original_dir, output_dir, split_csv_dir, size=84): splits = ['train', 'val', 'test'] for split in splits: split_dir = os.path.join(output_dir, split) os.makedirs(split_dir, exist_ok=True) df = pd.read_csv(os.path.join(split_csv_dir, f'{split}.csv')) for _, row in df.iterrows(): src_path = find_original_path(original_dir, row['filename']) dst_dir = os.path.join(split_dir, row['label']) os.makedirs(dst_dir, exist_ok=True) if size > 0: img = cv2.imread(src_path) img = cv2.resize(img, (size, size)) cv2.imwrite(os.path.join(dst_dir, row['filename']), img) else: copyfile(src_path, os.path.join(dst_dir, row['filename']))

3. 构建支持Episode采样的数据集类

3.1 自定义ImageFolder数据集

PyTorch的ImageFolder默认按文件夹顺序分配标签,我们需要自定义以保持与原始标签一致:

from torchvision.datasets import ImageFolder import numpy as np class MiniImageNetDataset(ImageFolder): def __init__(self, root, phase='train', transform=None): self.phase = phase self.class_to_idx = self._load_class_mapping() super().__init__(root=os.path.join(root, phase), transform=transform) def _load_class_mapping(self): # 加载CSV文件建立标签映射 df = pd.read_csv(f'./splits/{self.phase}.csv') unique_labels = df['label'].unique() return {label: idx for idx, label in enumerate(unique_labels)} def find_classes(self, directory): return list(self.class_to_idx.keys()), self.class_to_idx

3.2 实现Episode采样器

少样本学习通常采用episode训练方式,每个episode包含N个类,每个类K个样本:

from torch.utils.data.sampler import Sampler import numpy as np class EpisodeSampler(Sampler): def __init__(self, labels, n_way, k_shot, n_query, n_episodes): self.labels = labels self.n_way = n_way self.k_shot = k_shot self.n_query = n_query self.n_episodes = n_episodes # 按类别组织样本索引 self.class_indices = {} for idx, label in enumerate(labels): if label not in self.class_indices: self.class_indices[label] = [] self.class_indices[label].append(idx) def __iter__(self): for _ in range(self.n_episodes): # 随机选择n_way个类别 selected_classes = np.random.choice( list(self.class_indices.keys()), self.n_way, replace=False) batch = [] for cls in selected_classes: # 从当前类中随机选择k_shot + n_query个样本 indices = np.random.choice( self.class_indices[cls], self.k_shot + self.n_query, replace=False) batch.extend(indices) yield batch def __len__(self): return self.n_episodes

3.3 构建数据加载流程

将自定义数据集和采样器组合起来:

from torch.utils.data import DataLoader from torchvision import transforms transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) dataset = MiniImageNetDataset(root='./miniImageNet', phase='train', transform=transform) sampler = EpisodeSampler(dataset.targets, n_way=5, k_shot=5, n_query=15, n_episodes=100) dataloader = DataLoader(dataset, batch_sampler=sampler)

4. 高级技巧与优化

4.1 数据增强策略

少样本学习尤其需要强大的数据增强来防止过拟合:

from torchvision import transforms train_transform = transforms.Compose([ transforms.RandomResizedCrop(84), transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) test_transform = transforms.Compose([ transforms.Resize(92), transforms.CenterCrop(84), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ])

4.2 缓存加速

对于频繁访问的小样本,可以缓存到内存:

from torch.utils.data import Dataset class CachedDataset(Dataset): def __init__(self, dataset): self.dataset = dataset self.cache = {} def __getitem__(self, index): if index not in self.cache: self.cache[index] = self.dataset[index] return self.cache[index] def __len__(self): return len(self.dataset)

4.3 分布式训练支持

修改采样器以支持分布式训练:

import torch.distributed as dist class DistributedEpisodeSampler(EpisodeSampler): def __init__(self, labels, n_way, k_shot, n_query, n_episodes, num_replicas=None, rank=None): if num_replicas is None: num_replicas = dist.get_world_size() if rank is None: rank = dist.get_rank() super().__init__(labels, n_way, k_shot, n_query, n_episodes // num_replicas) self.num_replicas = num_replicas self.rank = rank self.epoch = 0 def set_epoch(self, epoch): self.epoch = epoch def __iter__(self): # 确保每个epoch的随机选择在不同进程中一致 g = torch.Generator() g.manual_seed(self.epoch) indices = super().__iter__() # 只返回属于当前rank的episodes return (indices[i] for i in range(self.rank, len(indices), self.num_replicas))
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/22 2:07:51

拆解一个USB3.0扩展坞:从VL817/VL822芯片Layout看消费电子的成本与性能博弈

拆解USB3.0扩展坞:VL817芯片Layout中的成本与性能平衡术 当我们拆开一个售价不到百元的USB3.0扩展坞,往往会惊讶于其内部结构的精简——这背后是消费电子领域永恒的成本与性能博弈。本文将以威盛VL817芯片方案为例,通过实际拆解和信号测试&a…

作者头像 李华
网站建设 2026/4/22 2:06:48

VSCode集成AI编程助手提升开发效率指南

1. 为什么要在VSCode中集成AI编程助手作为每天与代码打交道的开发者,我们经常会在编码过程中遇到各种问题:某个API的用法记不清了、需要优化一段性能不佳的代码、或者想快速生成一些样板代码。传统做法是切出编辑器去搜索引擎查找,这种上下文…

作者头像 李华
网站建设 2026/4/22 2:01:27

浏览器端深度学习模型部署:TensorFlow.js实战

1. 项目概述:浏览器端深度学习模型实战在浏览器里直接跑深度学习模型?这听起来像是2017年之前的科幻场景。但当我第一次用TensorFlow.js在Chrome里加载VGG16处理图片分类时,页面刷新后3秒内就显示出"金毛犬-置信度92%"的结果&#…

作者头像 李华
网站建设 2026/4/22 1:57:37

如何快速实现Android PDF打印:面向开发者的完整指南

如何快速实现Android PDF打印:面向开发者的完整指南 【免费下载链接】AndroidPdfViewer Android view for displaying PDFs rendered with PdfiumAndroid 项目地址: https://gitcode.com/gh_mirrors/an/AndroidPdfViewer 还在为Android应用中PDF打印功能而烦…

作者头像 李华