在 PyTorch 中,使用 torch.utils.data.DataLoader 类可以实现批量的数据集加载,在我们训练模型中非常常用,其功能也确实比较强度大。由于其参数比较多,我们将会对其用法进行详解。
-
DataLoader 的基本使用
-
DataLoader 的 collate_fn 参数
-
DataLoader 的 sampler 参数
-
DataLoader 的基本使用 {#title-0} ==============================
使用 DataLoader 前,我们先实现一个用于获得数据的类,假设名字为: MyDataset,其需要实现以下几个方法:
- init 方法用于对类对象进行初始化
- len 方法用于返回数据集中样本的数量
- getitem 方法用于根据索引返回一条样本
接下来,将我们自己构造的 MyDataset 实例对象交给 DataLoader,由其对我们的数据集对象进行封装返回一个数据加载器。DataLoader 的 shuffle 参数可以指定是否打乱原有的数据集顺序,batch_size 参数用于指定每次加载的批次样本数量,drop_last 参数指定最后一组不够批次数量的样本是否丢弃。
shuffle 参数默认值是 False ,batch_size 参数默认值是 1,drop_last 参数默认值是 False。
import torch
from torch.utils.data import DataLoader
class MyDataset:
def __init__(self, x, y):
self.x = x
self.y = y
self.sample_number = len(self.y)
def __len__(self):
return self.sample_number
def __getitem__(self, idx):
# 修正 idx 范围为 [0, idx]
idx = min(max(idx, 0), self.sample_number - 1)
# 返回一组样本
return self.x[idx], self.y[idx]
def test():
# 构造数据集
x = torch.arange(21).reshape(21, 1)
y = torch.arange(21)
# 初始化数据集
dataset = MyDataset(x, y)
# 初始化数据加载器
dataloader = DataLoader(dataset, shuffle=True, batch_size=8, drop_last=True)
for tx, ty in dataloader:
print(tx)
if __name__ == '__main__':
test()
程序输出结果:
tensor([[ 7],
[14],
[11],
[15],
[ 9],
[12],
[ 3],
[ 5]])
tensor([[ 8],
[10],
[ 1],
[ 4],
[13],
[ 6],
[ 0],
[ 2]])
从程序可以看到,我们的样本数量为 21,每 8 个样本组成一个批次,由于设置了 drop_last 为 True,所以共打印了 2 个批次的训练数据,并且由于 shuffle 参数被设置为 True,每一个批次的样本都是被打乱的,并不是按照原来的样本数量。 注意:在上面的例子中,MyDataset 类并没有继承 torch.utils.data.Dataset 类。
- DataLoader 的 collate_fn 参数 {#title-1} ========================================
collate_fn 参数用于接收用于传递的一个函数。DataLoader 会从数据集中获得一个批次的数据,然后将该批次数据再传递到 collate_fn 指向的函数中进行二次处理。
import torch
from torch.utils.data import DataLoader
class MyDataset:
def __init__(self, x, y):
self.x = x
self.y = y
self.sample_number = len(self.y)
def __len__(self):
return self.sample_number
def __getitem__(self, idx):
# 修正 idx 范围为 [0, idx]
idx = min(max(idx, 0), self.sample_number - 1)
# 返回一组样本
return self.x[idx], self.y[idx]
def secondary_processing(data):
# 在此函数中可以对数据集进行二次处理
# 传递进行的批次数据 [(样本1, 目标值1) ... (样本2, 目标值2)]
feature = []
target = []
for x, y in data:
feature.append(x.tolist())
target.append(y.item())
feature = torch.tensor(feature)
target = torch.tensor(target)
return feature, target
def test():
# 构造数据集
x = torch.arange(16).reshape(16, 1)
y = torch.arange(100, 116)
# 初始化数据集
dataset = MyDataset(x, y)
# 初始化数据加载器
dataloader = DataLoader(dataset,
shuffle=True,
batch_size=8,
collate_fn=secondary_processing)
for tx, ty in dataloader:
print(tx)
if __name__ == '__main__':
test()
- DataLoader 的 sampler 参数 {#title-2} =====================================
sampler 用于设置如何从数据集中提取样本,即: 数据采样策略。如果指定该参数,则 shuffle 参数则会被忽略。在 DataLoader 中内置了几种采样器:
- SequentialSampler 采样策略表示按照样本的顺序进行采样
- BatchSampler 采样策略表示按照指定的批次索引进行采样
- RandomSampler 采样策略表示进行随机采样、以及是否允许有放回的采样
- SubsetRandomSampler 采样策略表示按照指定的集合或者索引列表进行随机采样
- WeightedRandomSampler 采样策略表示按照指定的概率进行随机采样
接下来,我们代码演示下上面一些采样策略的用法:
import torch
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
from torch.utils.data import BatchSampler
from torch.utils.data import RandomSampler
from torch.utils.data import SubsetRandomSampler
from torch.utils.data import WeightedRandomSampler
class MyDataset:
def __init__(self, x, y):
self.x = x
self.y = y
self.sample_number = len(self.y)
def __len__(self):
return self.sample_number
def __getitem__(self, idx):
if isinstance(idx, int):
idx = min(max(idx, 0), self.sample_number - 1)
return self.x[idx], self.y[idx]
if isinstance(idx, list):
xs = []
ys = []
for i in idx:
xs.append(self.x[i])
ys.append(self.y[i])
return xs, ys
# 1. SequentialSampler 的用法
def get_dataloader1(x, y):
dataset = MyDataset(x, y)
# SequentialSampler 需要将 dataset 作为参数
# SequentialSampler 获得原始数据的索引
sampler = SequentialSampler(dataset)
# 由于 SequentialSampler 没有指定 batch_size 的参数, 需要在 DataLoader 中设置
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)
return dataloader
# 2. BatchSampler 的用法
def get_dataloader2(x, y):
dataset = MyDataset(x, y)
# 在指定索引列表, 根据 batch_size 产生顺序产生批次数据
# [3, 4, 5]、[7, 8, 9]、[10] 作为一个批次
sampler = BatchSampler([3, 4, 5, 7, 8, 9, 10], batch_size=3, drop_last=False)
# 由于 BatchSampler 指定了 batch_size, 在 DataLoader 中不需要指定
dataloader = DataLoader(dataset, sampler=sampler)
return dataloader
# 3. RandomSampler 的用法
def get_dataloader3(x, y):
dataset = MyDataset(x, y)
# RandomSampler 需要将 dataset 作为参数
# RandomSampler 获得原始数据的索引
sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)
return dataloader
def get_dataloader4(x, y):
dataset = MyDataset(x, y)
# 随机从 [3, 4, 5, 7, 8, 9, 10] 中产生批次
sampler = SubsetRandomSampler([3, 4, 5, 7, 8, 9, 10])
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
return dataloader
def get_dataloader5(x, y):
dataset = MyDataset(x, y)
# 随机从前 num_samples 个样本中,根据概率值中产生批次
# replacement 参数表示否重复采样
sampler = WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6, 0.4, 0.7, 3.0, 0.6],
num_samples=10,
replacement=False)
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
return dataloader
def test():
x = torch.arange(16).reshape(16, 1)
y = torch.arange(100, 116)
dataloader = get_dataloader5(x,y)
for tx, ty in dataloader:
print(tx, ty)
if __name__ == '__main__':
test()
- DataLoader Dataset 和 sampler 的关系 {#title-3} ==============================================
DataLoader 使用 sampler 产生数据索引,根据索引从 Dataset 中获得批次数据。