news 2026/4/18 8:02:18

Day 42 Dataset 和 Dataloader 类

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 42 Dataset 和 Dataloader 类

@浙大疏锦行

一、核心定位

核心角色核心作用
Dataset「数据容器」/「数据加工厂」定义单条数据的读取、预处理逻辑(如从 CSV 读一行、编码、填充、标准化),支持按索引取数
DataLoader「数据搬运工」/「批量调度器」封装Dataset,实现批量加载、数据打乱、多线程读取、分批迭代,解决显存和效率问题

简单来说:

  • Dataset解决 “单条数据怎么来” 的问题;
  • DataLoader解决 “怎么把数据批量喂给模型” 的问题。

二、Dataset类详解

PyTorch 提供了内置Dataset(如TensorDatasetImageFolder),但实际项目中,必须自定义Dataset,需继承torch.utils.data.Dataset并实现两个核心方法:

  1. __len__():返回数据集的总条数(让DataLoader知道有多少数据);
  2. __getitem__(idx):根据索引idx返回单条数据(特征 + 标签),是预处理的核心。

1. 自定义Dataset实战

import torch from torch.utils.data import Dataset, DataLoader import pandas as pd import numpy as np from sklearn.preprocessing import StandardScaler class CreditDefaultDataset(Dataset): """ 信用违约预测自定义Dataset :param data_path: CSV数据路径 :param is_train: 是否为训练集(用于区分训练/测试的标准化) :param scaler: 标准化器(训练集拟合,测试集复用) """ def __init__(self, data_path, is_train=True, scaler=None): self.data_path = data_path self.is_train = is_train self.scaler = scaler # 1. 读取数据 + 预处理(复用你之前的逻辑) self.data = self._preprocess_data() # 2. 分离特征和标签 self.X = self.data.drop(['Credit Default'], axis=1).values # 特征数组 self.y = self.data['Credit Default'].values # 标签数组 # 3. 标准化(训练集拟合scaler,测试集复用) if self.is_train: self.scaler = StandardScaler() self.X = self.scaler.fit_transform(self.X) else: assert self.scaler is not None, "测试集必须传入训练集的scaler!" self.X = self.scaler.transform(self.X) def _preprocess_data(self): """封装你之前的预处理逻辑(编码、填充缺失值)""" data = pd.read_csv(self.data_path) data2 = pd.read_csv(self.data_path) # 1. 字符串变量编码 # Home Ownership 标签编码 home_mapping = {'Own Home':1, 'Rent':2, 'Have Mortgage':3, 'Home Mortgage':4} data['Home Ownership'] = data['Home Ownership'].map(home_mapping) # Years in current job 标签编码 years_mapping = {'< 1 year':1, '1 year':2, '2 years':3, '3 years':4, '4 years':5, '5 years':6, '6 years':7, '7 years':8, '8 years':9, '9 years':10, '10+ years':11} data['Years in current job'] = data['Years in current job'].map(years_mapping) # Purpose 独热编码 + bool转int data = pd.get_dummies(data, columns=['Purpose']) new_cols = [col for col in data.columns if col not in data2.columns] for col in new_cols: data[col] = data[col].astype(int) # Term 映射 + 重命名 term_mapping = {'Short Term':0, 'Long Term':1} data['Term'] = data['Term'].map(term_mapping) data.rename(columns={'Term': 'Long Term'}, inplace=True) # 2. 缺失值填充(连续特征用众数) cont_feats = data.select_dtypes(include=['int64', 'float64']).columns.tolist() for feat in cont_feats: mode_val = data[feat].mode()[0] data[feat].fillna(mode_val, inplace=True) return data def __len__(self): """返回数据集总条数(必须实现)""" return len(self.X) def __getitem__(self, idx): """ 根据索引返回单条数据(必须实现) :param idx: 数据索引(int) :return: (特征tensor, 标签tensor) """ # 取单条数据(numpy数组) x = self.X[idx] y = self.y[idx] # 转换为PyTorch张量(MLP需要float32类型) x_tensor = torch.from_numpy(x).float() y_tensor = torch.from_numpy(np.array(y)).float() # 二分类标签用float(适配BCELoss) return x_tensor, y_tensor

2.Dataset核心方法解释

方法作用
__init__初始化:读取数据、预处理、分离特征 / 标签、标准化(核心预处理逻辑都在这里)
__len__返回数据总条数,DataLoader会用这个方法知道 “有多少批数据”
__getitem__按索引取单条数据,是DataLoader批量加载的基础(每次取 1 条,再拼 batch)
自定义方法(如_preprocess_data封装预处理逻辑,让代码更整洁(非必须,但推荐)

三、DataLoader类详解

DataLoaderDataset的 “上层封装”,核心作用是Dataset中的单条数据拼成批次,并提供高效读取能力。

1. 核心参数

参数作用训练集 / 测试集建议
dataset传入自定义的Dataset实例(必须)-
batch_size每批数据的条数(如 32、64)训练集:32/64;测试集:可更大(如 128)
shuffle是否打乱数据(避免模型学习顺序)训练集:True;测试集:False
num_workers多线程读取数据(加速)Windows:0;Linux/Mac:4/8(根据 CPU 核数)
drop_last是否丢弃最后一批不足batch_size的数据训练集:True;测试集:False
pin_memory是否锁定内存(GPU 训练时加速数据传输)GPU 训练:True;CPU:False

2.DataLoader实战

# 数据路径(替换为你的实际路径) DATA_PATH = r"E:\study\PythonStudy\python60-days-challenge-master\data.csv" # Step1:创建训练集Dataset(拟合scaler) train_dataset = CreditDefaultDataset(data_path=DATA_PATH, is_train=True) # 提取训练集的scaler(供测试集复用) train_scaler = train_dataset.scaler # Step2:划分训练/测试集(可选:如果Dataset已包含全量数据,这里拆分) # 注意:也可以在Dataset中直接划分,这里用切片示例 train_size = int(0.8 * len(train_dataset)) test_size = len(train_dataset) - train_size train_subset, test_subset = torch.utils.data.random_split( train_dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42) ) # 测试集Dataset(复用训练集的scaler) test_dataset = CreditDefaultDataset(data_path=DATA_PATH, is_train=False, scaler=train_scaler) # Step3:创建DataLoader(核心) train_loader = DataLoader( dataset=train_subset, batch_size=32, shuffle=True, # 训练集打乱 num_workers=0, # Windows建议设0,避免多线程报错 drop_last=True, # 丢弃最后一批不足32条的数据 pin_memory=True if torch.cuda.is_available() else False # GPU加速 ) test_loader = DataLoader( dataset=test_subset, batch_size=64, shuffle=False, # 测试集不打乱 num_workers=0, drop_last=False, pin_memory=True if torch.cuda.is_available() else False ) # Step4:迭代DataLoader(训练/测试时的核心用法) # 示例:遍历训练集批次 print("===== 训练集批次示例 =====") for batch_idx, (x_batch, y_batch) in enumerate(train_loader): print(f"批次 {batch_idx+1}:") print(f" 特征形状:{x_batch.shape} (batch_size={x_batch.shape[0]}, 特征数={x_batch.shape[1]})") print(f" 标签形状:{y_batch.shape}") # 训练时:将批次数据移到GPU → 前向传播 → 反向传播 if batch_idx == 2: # 只打印前3批 break # 示例:遍历测试集批次 print("\n===== 测试集批次示例 =====") for batch_idx, (x_batch, y_batch) in enumerate(test_loader): print(f"批次 {batch_idx+1}:特征形状={x_batch.shape},标签形状={y_batch.shape}") if batch_idx == 1: break

3.DataLoader迭代逻辑解释

  • enumerate(train_loader)会逐批返回(批次索引, (特征批次, 标签批次))
  • 特征批次形状:(batch_size, 特征数)(如 32 个样本,20 个特征 →(32, 20));
  • 标签批次形状:(batch_size,)(如 32 个样本 →(32,));
  • 训练时,每批数据会被喂给模型:outputs = model(x_batch)
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/16 20:43:16

安克创新的AB面:创始人分红过亿,存货却压垮现金流

"为何渴求港股二次上市&#xff1f;" 作者 | 王冲和 编辑 | 卢旭成 前不久&#xff0c;安克创新正式向港交所递交了主板上市申请&#xff0c;这个“充电宝第一股”再次被世人关注。 早在2020年8月24日&#xff0c;安克创新已经登陆深交所创业板&#xff0c;上市首…

作者头像 李华
网站建设 2026/4/17 21:27:19

农业物联网通信难题如何破解:3步实现Agent间无缝协同

第一章&#xff1a;农业物联网Agent通信的挑战与演进在现代农业物联网&#xff08;IoT&#xff09;系统中&#xff0c;分布式智能设备&#xff08;即Agent&#xff09;之间的高效通信是实现精准农业的核心。随着传感器网络、边缘计算和自动化农机具的广泛应用&#xff0c;农业场…

作者头像 李华
网站建设 2026/3/22 6:03:00

【首发】Agentic RAN:智能体时代的下一代无线接入网

【摘要】智能体时代的无线接入网应该是什么样的&#xff1f;本文首创性地提出一个全新的概念和定义“Agentic RAN”&#xff1a;以智能体实现无线接入网的自感知、自决策、自执行优化&#xff0c;并在基站/汇聚侧提供边缘AI算力与能力编排&#xff0c;构建“云—边—端”一体的…

作者头像 李华
网站建设 2026/4/15 22:38:49

边缘Agent部署必须掌握的7个关键技术点(附最佳实践)

第一章&#xff1a;边缘Agent部署的核心挑战在现代分布式系统架构中&#xff0c;边缘Agent作为连接中心平台与终端设备的桥梁&#xff0c;承担着数据采集、本地决策和指令执行等关键任务。然而&#xff0c;其部署过程面临诸多技术难题&#xff0c;尤其是在资源受限、网络不稳定…

作者头像 李华
网站建设 2026/4/17 8:40:21

小程序毕设选题推荐:基于微信小程序的集换社卡牌的交易系统基于springboot+微信小程序的集换社卡牌的交易系统小程序【附源码、mysql、文档、调试+代码讲解+全bao等】

博主介绍&#xff1a;✌️码农一枚 &#xff0c;专注于大学生项目实战开发、讲解和毕业&#x1f6a2;文撰写修改等。全栈领域优质创作者&#xff0c;博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围&#xff1a;&am…

作者头像 李华
网站建设 2026/4/17 17:58:22

【课程设计/毕业设计】基于Android的乡村研学旅行APP系统app小程序基于springboot+Android的研学旅行服务平台APP小程序设计【附源码、数据库、万字文档】

博主介绍&#xff1a;✌️码农一枚 &#xff0c;专注于大学生项目实战开发、讲解和毕业&#x1f6a2;文撰写修改等。全栈领域优质创作者&#xff0c;博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围&#xff1a;&am…

作者头像 李华