Pytorch与深度学习自查手册2-数据加载和预处理

数据加载

DataSet类

自定义一个继承 Dataset类的类 ,需要重写以下三个函数:

  1. __init__:传入数据,或者像下面一样直接在函数里加载数据;
  2. __len__:返回这个数据集一共有多少个item;
  3. __getitem__:返回一条训练数据,并将其转换成tensor。
  4. 通常还会在其中增加一个collate_fn函数,用于DataLoader,使用这个参数可以自己操作每个batch的数据,比如说在自然语言处理的命名实体识别任务中,在该函数中对每个batch中的样本都padding到同一长度等。
import torch
from torch.utils.data import Dataset
class Mydata(Dataset):
def __init__(self,path):
#加载数据
a = np.load("a.npy",allow_pickle=True)
b = np.load("b.npy",allow_pickle=True)
d = np.load("d.npy",allow_pickle=True)
c = np.load("c.npy")
self.x = list(zip(a,b,d,c))
self.y = ...
def __getitem__(self, idx):
assert idx < len(self.x)
return self.x[idx],self.y[idx]
def __len__(self):
return len(self.x)
def collate_fn(self,batch):
#……
pass
mydataset=Mydata(path, transform= transform_func)

collate_fn:如何取样本

一般的,默认的collate_fn函数是要求一个batch中的图片都具有相同size(因为要做stack操作),当一个batch中的图片大小都不同时,可以使用自定义的collate_fn函数,则一个batch中的图片不再被stack操作,可以全部存储在一个list中,当然还有对应的label

# a simple custom collate function, just to show the idea 
def my_collate(batch):
data = [item[0] for item in batch]
target = [item[1] for item in batch]
target = torch.LongTensor(target)
return [data, target]

DataLoader类

DataLoader包括三个参数:

  1. dataset:传入的数据;
  2. shuffle = True:是否打乱数据;
  3. collate_fn函数:使用这个参数可以自己操作每个batch的数据。
  4. drop_last:告诉如何处理划分batch后剩下的最后不足一个batch的样本集合,True就抛弃,否则保留。
from torch.utils.data import DataLoader
dataset = Mydata()
#构建DataLoader
dataloader = DataLoader(dataset, batch_size = 2, shuffle=True,collate_fn = dataset.collate_fn)

从DataLoader中取样本

#从dataloader中逐一取样本
train_features, train_labels = next(iter(train_dataloader))
#循环取样本
for X, y in dataloader:
print("Shape of X [N, C, H, W]: ", X.shape)
print("Shape of y: ", y.shape, y.dtype)
break

数据预处理

自定义collate_fn函数传入

自定义collate_fn函数传入DataLoader。

transforms:对图片进行变换

注意:输入transforms的图片形状为[h,w,c],最后一维是channel,而transforms的输出是[c,h,w]

PyTorch 学习笔记(三):transforms的二十二个方法

transforms.ToTensor()转化数据类型

将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1] (numpy的数字类型必须为uint型,如果是其他类型不会进行归一化)

注意事项:归一化至[0-1]是直接除以255,若自己的ndarray数据尺度有变化,则需要自行修改。

from torchvision import transforms, utils
from PIL import Image
img=Image.image(img_path)
tensor_trans=transforms.ToTensor()
img_tensor=tensor_trans(img)

合并数据处理过程transforms.Compose()

trans_compose=transforms.Compose([transforms.Resize(),transforms.ToTensor()])
trans_compose(img)

正则化

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transforms.Resize(image_size)

随机裁剪/中心裁剪

transforms. RandomCrop((512,1000))
transforms.CenterCrop(image_size)

整数转one-hot

num_class=10
target_transform = Lambda(lambda y: torch.zeros(
num_class, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

参考资料

图神经网络的下游任务3-图分类 | 冬于的博客 (ifwind.github.io)

PyTorch 学习笔记(三):transforms的二十二个方法