Pytorch数据读取与预处理如何为?
摘要:在炼丹时,数据的读取与预处理是关键一步。不同的模型所需要的数据以及预处理方式各不相同,如果每个轮子都我们自己写的话,是很浪费时间和精力的。Pytorch帮我们实现了方便的数据读取与预处理方法,下面记录两个DEMO,便于加快以后的代码效率。
在炼丹时,数据的读取与预处理是关键一步。不同的模型所需要的数据以及预处理方式各不相同,如果每个轮子都我们自己写的话,是很浪费时间和精力的。Pytorch帮我们实现了方便的数据读取与预处理方法,下面记录两个DEMO,便于加快以后的代码效率。
根据数据是否一次性读取完,将DEMO分为:
1、串行式读取。也就是一次性读取完所有需要的数据到内存,模型训练时不会再访问外存。通常用在内存足够的情况下使用,速度更快。
2、并行式读取。也就是边训练边读取数据。通常用在内存不够的情况下使用,会占用计算资源,如果分配的好的话,几乎不损失速度。
Pytorch官方的数据提取方式尽管方便编码,但由于它提取数据方式比较死板,会浪费资源,下面对其进行分析。
串行式读取
DEMO代码
import torch
from torch.utils.data import Dataset,DataLoader
class MyDataSet(Dataset):# ————1————
def __init__(self):
self.data = torch.tensor(range(10)).reshape([5,2])
self.label = torch.tensor(range(5))
def __getitem__(self, index):
return self.data[index], self.label[index]
def __len__(self):
return len(self.data)
my_data_set = MyDataSet()# ————2————
my_data_loader = DataLoader(
dataset=my_data_set, # ————3————
batch_size=2, # ————4————
shuffle=True, # ————5————
sampler=None, # ————6————
batch_sampler=None, # ————7————
num_workers=0 , # ————8————
collate_fn=None, # ————9————
pin_memory=True, # ————10————
drop_last=True # ————11————
)
for i in my_data_loader: # ————12————
print(i)
注释处解释如下:
1、重写数据集类,用于保存数据。除了 __init__() 外,必须实现 __getitem__() 和 __len__() 两个方法。前一个方法用于输出索引对应的数据。后一个方法用于获取数据集的长度。
2~5、 2准备好数据集后,传入DataLoader来迭代生成数据。前三个参数分别是传入的数据集对象、每次获取的批量大小、是否打乱数据集输出。
6、采样器,如果定义这个,shuffle只能设置为False。所谓采样器就是用于生成数据索引的可迭代对象,比如列表。因此,定义了采样器,采样都按它来,shuffle再打乱就没意义了。
7、批量采样器,如果定义这个,batch_size、shuffle、sampler、drop_last都不能定义。实际上,如果没有特殊的数据生成顺序的要求,采样器并没有必要定义。torch.utils.data 中的各种 Sampler 就是采样器类,如果需要,可以使用它们来定义。
8、用于生成数据的子进程数。默认为0,不并行。
9、拼接多个样本的方法,默认是将每个batch的数据在第一维上进行拼接。这样可能说不清楚,并且由于这里可以探究一下获取数据的速度,后面再详细说明。
10、是否使用锁页内存。用的话会更快,内存不充足最好别用。
11、是否把最后小于batch的数据丢掉。
12、迭代获取数据并输出。
速度探索
首先看一下DEMO的输出:
输出了两个batch的数据,每组数据中data和label都正确排列,符合我们的预期。
