Pytorch与深度学习自查手册2-数据加载和预处理
Pytorch与深度学习自查手册2-数据加载和预处理
数据加载
DataSet类
自定义一个继承 Dataset类的类 ,需要重写以下三个函数:
__init__
:传入数据,或者像下面一样直接在函数里加载数据;__len__
:返回这个数据集一共有多少个item;__getitem__
:返回一条训练数据,并将其转换成tensor。- 通常还会在其中增加一个
collate_fn函数
,用于DataLoader
,使用这个参数可以自己操作每个batch的数据,比如说在自然语言处理的命名实体识别任务中,在该函数中对每个batch中的样本都padding
到同一长度等。
import torch |
collate_fn:如何取样本
一般的,默认的collate_fn
函数是要求一个batch中的图片都具有相同size(因为要做stack操作),当一个batch中的图片大小都不同时,可以使用自定义的collate_fn
函数,则一个batch中的图片不再被stack操作,可以全部存储在一个list中,当然还有对应的label
# a simple custom collate function, just to show the idea |
DataLoader类
DataLoader包括三个参数:
dataset
:传入的数据;shuffle
= True:是否打乱数据;collate_fn
函数:使用这个参数可以自己操作每个batch的数据。drop_last
:告诉如何处理划分batch后剩下的最后不足一个batch的样本集合,True就抛弃,否则保留。
from torch.utils.data import DataLoader |
从DataLoader中取样本
#从dataloader中逐一取样本 |
数据预处理
自定义collate_fn函数传入
自定义collate_fn
函数传入DataLoader。
transforms:对图片进行变换
注意:输入transforms的图片形状为[h,w,c]
,最后一维是channel,而transforms的输出是[c,h,w]
。
PyTorch 学习笔记(三):transforms的二十二个方法
transforms.ToTensor()转化数据类型
将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1] (numpy的数字类型必须为uint型,如果是其他类型不会进行归一化)
注意事项:归一化至[0-1]是直接除以255,若自己的ndarray数据尺度有变化,则需要自行修改。
from torchvision import transforms, utils |
合并数据处理过程transforms.Compose()
trans_compose=transforms.Compose([transforms.Resize(),transforms.ToTensor()]) |
正则化
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
随机裁剪/中心裁剪
transforms. RandomCrop((512,1000)) |
整数转one-hot
num_class=10 |
参考资料
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 冬于的博客!
评论