Pytorch 中的 Dataset、DataLoader
Pytorch
中的数据管理由 Dataset
和 DataLoader
两个类实现。
Pytorch 中的数据处理
Pytorch 数据处理的步骤:Pytorch
中处理和加载数据一般需要经历以下步骤
- 加载原始数据:从硬盘读取数据(例如图片、文本、CSV 文件)
- 预处理数据:将原始数据转换成 Tensor,进行必要的变换(如尺寸调整、归一化)
- 分批(Batching):将整个数据集分成一个个小批次(mini-batch)
- 迭代:将这些批次数据送入模型进行训练或评估
Dataset 与 DataLoader 的分工:Pytorch
中主要用 Dataset
和 DataLoader
两个类实现了上述操作
Dataset
:负责加载和预处理单个数据样本,使用__getitem__
读取每一个数据点。DataLoader
:负责从Dataset
中取出数据,并将其打包成批次以供训练
torch.utils.data.Dataset 类
Dataset 类基本使用:Dataset
是一个抽象类,我们需要创建自己的类来继承它,并实现两个方法:
__len__(self)
:返回数据集样本的总数。DataLoader
会用它来确定迭代的次数。_getitem__(self, idx)
:接收一个索引idx
,返回数据集中对应的一个样本。数据加载和预处理(transform
)一般在这里完成。
示例:从文件夹读取图片:假设我们的图片存储在如下目录中:
我们可以创建一个 CustomImageDataset
= # 用一个列表存储 (图片路径, 标签)
= 0
= 1
continue
# 保存 transform
=
return
# 1. 获取图片路径和标签
=
, # 2. 读取图片
=
# 3. 如果定义了 transform,则对图片进行处理
=
# 4. 返回处理后的图片和对应的标签
return ,
数据变换与增强 transform:原始数据很少能直接送入模型。图片需要被统一尺寸、转换为 Tensor、并进行归一化;文本需要被分词、转换为 ID。transform
就是用来完成这些预处理工作的。例如 torchvision.transforms
模块为图像处理提供了大量现成的工具
# 定义一系列变换
# 1. 将图片缩放到 224x224
# 2. 随机水平翻转(数据增强)
# 3. 转换为 Tensor(将 HWC 的 PIL Image 或 NumPy 数组转换为 CHW 的 Tensor,并将像素值从 [0, 255] 缩放到 [0.0, 1.0])
# 4. 归一化
=
# 实例化 Dataset 时传入 transform
=
# 现在,当我们通过 my_dataset[0] 获取数据时,得到的图片已经是经过上述所有变换后的 Tensor 了。
torch.utils.data.DataLoader 类
DataLoader 的基本使用:DataLoader
是一个迭代器,它从 Dataset
中自动抓取数据,打包成批次。其包含如下核心参数:
dataset
: 我们刚刚创建的Dataset
实例。batch_size
(int): 每个批次包含的样本数。默认为 1。shuffle
(bool): 是否在每个 epoch 开始时打乱数据顺序。训练时通常设为True
,验证/测试时设为False
。num_workers
(int): 用于数据加载的子进程数。0
表示在主进程中加载。大于0
的值可以显著加速数据加载,避免 GPU 等待。collate_fn
(callable, optional): 用于将多个样本合并成一个批次的函数。下面会重点讲解。pin_memory
(bool): 如果为True
,数据加载器会将张量复制到 CUDA 固定内存中,这可以加快数据到 GPU 的传输速度。
示例:创建和使用 DataLoader
# 假设 my_dataset 已经创建好了
# 创建 DataLoader
=
# 如何使用 DataLoader 进行迭代
= 10
# train_loader 是一个可迭代对象
# 在这里,batch_images 和 batch_labels 就是一个批次的数据
# batch_images 的形状通常是 (batch_size, channels, height, width) -> (32, 3, 224, 224)
# batch_labels 的形状通常是 (batch_size) -> (32)
# 将数据移动到 GPU (如果可用)
# batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
# ... 接下来是模型前向传播、计算损失、反向传播等步骤 ...
# 打印一个批次的形状看看
break # 这里只演示一个批次
break
collate_fn 自定义批次合并逻辑:DataLoader
的默认合并函数 default_collate
简单地使用 torch.stack()
将样本列表堆叠成一个批次,其使用场景是 Dataset
返回的每个样本都有相同的形状。但是,如果样本形状不一,default_collate
就会报错。最典型的例子是自然语言处理(NLP),每个句子的长度(词元数量)都不同。此时就需要自定义 collate_fn
,其工作逻辑为:
DataLoader
从Dataset
中取出batch_size
个样本,形成一个列表,例如[sample1, sample2, ..., sample_batch_size]
。DataLoader
将这个列表传递给collate_fn
函数。collate_fn
函数负责处理这个列表,并返回一个或多个已经正确打包好的批次 Tensor。
示例:处理不同长度的序列:假设我们有一个 Dataset
返回的是 (序列, 标签)
,其中序列是不同长度的 Tensor。
# 1. 创建一个返回可变长度序列的数据集
=
return
return
# 2. 定义自定义的 collate_fn
"""
Args:
batch: 一个列表,其中每个元素都是 Dataset 的一个返回值,即 (sequence, label)
例如: [(tensor([1,2,3]), 0), (tensor([4,5]), 1)]
"""
# 将序列和标签分开
=
=
# 对序列进行填充(padding),使它们长度一致
# pad_sequence 会自动找到批次中最长的序列,并将其他序列填充到该长度
# batch_first=True 表示返回的 Tensor 形状为 (batch_size, max_seq_length)
=
# 将标签列表转换为 Tensor
=
return ,
# 3. 在 DataLoader 中使用 custom_collate_fn
=
=
# 4. 迭代并查看结果
# 预期输出:
# Padded Sequences Batch:
# tensor([[1, 2, 3],
# [4, 5, 0]]) <-- 第二个序列被填充了一个0
# Shape: torch.Size([2, 3])
# Labels Batch:
# tensor([0, 1])
# Shape: torch.Size([2])
# --------------------
# Padded Sequences Batch:
# tensor([[6, 7, 8, 9],
# [10, 0, 0, 0]]) <-- 第二个序列被填充了三个0
# Shape: torch.Size([2, 4])
# Labels Batch:
# tensor([0, 1])
# Shape: torch.Size([2])
# --------------------