PyG-InMemoryDataset加载数据
PyG-InMemoryDataset加载数据
引言
我们这次重点关注第一部分,并提供2020年1月1日的twitter数据作为实验数据。首先分析继承InMemoryDataset类的Planetoid类,然后说明该类实例化流程,并详细说明如何继承该类设计自己的数据加载类,最后利用twitter数据写一个自己的数据加载类。
源码分析
我们以PyG中加载Planetoid数据集的源码为例进行分析,说明如何通过继承InMemoryDataset
类来自定义一个数据可全部存储到内存的数据集类。
源码如下:
from typing import Optional, Callable, List |
可以发现总共实现了8个函数,其中前5个函数为核心:
__init()__
raw_file_names(self)
processed_file_names(self)
download(self)
process(self)
raw_dir(self)
processed_dir(self)
__repr__(self)
接下来分别介绍各个函数的作用和功能。
*__init()__
用于初始化类,包括两个部分:初始化父类(InMemoryDataset
类)、划分数据集:
几个核心参数用于初始化
InMemoryDataset
类:root
:字符串类型,存储数据集的文件夹的路径下。该文件夹下有两个文件夹:- 一个文件夹为记录在
raw_dir
,它用于存储未处理的文件,从网络上下载的数据集原始文件会被存放到这里; - 另一个文件夹记录在
processed_dir
,处理后的数据被保存到这里,以后从此文件夹下加载文件即可获得Data
对象。 - 注:
raw_dir
和processed_dir
是属性方法,我们可以自定义要使用的文件夹。
- 一个文件夹为记录在
transform
:函数类型,一个数据转换函数,它接收一个Data
对象并返回一个转换后的Data
对象。此函数在每一次数据获取过程中都会被执行。获取数据的函数首先使用此函数对Data
对象做转换,然后才返回数据。此函数应该用于数据增广(Data Augmentation)。该参数默认值为None
,表示不对数据做转换。pre_transform
:函数类型,一个数据转换函数,它接收一个Data
对象并返回一个转换后的Data
对象。此函数在Data
对象被保存到文件前调用。因此它应该用于只执行一次的数据预处理。该参数默认值为None
,表示不做数据预处理。pre_filter
:函数类型,一个检查数据是否要保留的函数,它接收一个Data
对象,返回此Data
对象是否应该被包含在最终的数据集中。此函数也在Data
对象被保存到文件前调用。该参数默认值为None
,表示不做数据检查,保留所有的数据。
划分数据集
将数据集别用掩码(
mask
)划分为训练、验证和测试集三个部分。
*raw_file_names()
属性方法,返回一个数据集原始文件的文件名列表,数据集原始文件应该能在raw_dir
文件夹中找到,否则调用download()
函数下载文件到raw_dir
文件夹。
*processed_file_names()
属性方法,返回一个存储处理过的数据的文件的文件名列表,存储处理过的数据的文件应该能在processed_dir
文件夹中找到,否则调用process()
函数对样本做处理,然后保存处理过的数据到processed_dir
文件夹下的文件里。
*download()
根据定义的url
属性下载数据集原始文件到raw_dir
文件夹。
*process()
调用读取数据函数,将数据包装成Data,然后处理数据,保存处理好的数据到processed_dir
文件夹下的文件。
raw_dir()
属性方法,原始数据存储的文件夹路径,我们可以自定义要使用的文件夹。
processed_dir()
属性方法,处理后数据存储的文件夹路径,我们可以自定义要使用的文件夹。
__repr__()
这个函数对应repr(object)这个功能。意思是当需要显示一个对象在屏幕上时,将这个对象的属性或者是方法整理成一个可以打印输出的格式。python
__repr__
的作用
也就是说:
print(Planetoid)
会打印出:
Planetoid()
其他符号和设计
:
、->
、Optional[Callable]
参数+冒号+数据类型,如def __init__(self, root: str, name: str, split: str = "public",num_train_per_class: int = 20, num_val:int = 500,num_test: int = 1000, transform: Optional[Callable] = None,pre_transform: Optional[Callable] = None)
,以及函数+"->"+数据类型,如def raw_dir(self) -> str:
参考:Python函数参数中的冒号与箭头、Python中typing模块与类型注解的使用方法
函数参数中的冒号是参数的类型建议符,告诉程序员希望传入的实参的类型。
函数后面跟着的箭头是函数返回值的类型建议符,用来说明该函数返回的值是什么类型。
类型建议符并非强制规定和检查,也就是说即使传入的实际参数与建议参数不符,也不会报错。
Optional[Callable]
,Optional
表示这个参数可以为空或已经声明的类型,Callable
表示是一个可调用类型的参数(函数等)。
@property
@property
修饰的方法为属性方法,如:
|
参考[python @property的介绍与使用](https://zhuanlan.zhihu.com/p/64487092):
python的@property是python的一种装饰器,是用来修饰方法的。可以使用@property装饰器来创建只读属性,@property装饰器会将方法转换为相同名称的只读属性,可以与所定义的属性配合使用,这样可以防止属性被修改。
Planetoid类实例化流程
执行以下语句时,该类的实例化流程如下。
dataset = Planetoid(root='dataset/PlanetoidPubMed',transform=NormalizeFeatures()) |
- 首先,检查数据原始文件是否已下载:
- 检查
self.raw_dir
目录下是否存在raw_file_names()
属性方法返回的每个文件, - 如有文件不存在,则调用
download()
方法执行原始文件下载。
- 检查
- 其次,检查数据是否经过处理:
- 首先,检查之前对数据做变换的方法:检查
self.processed_dir
目录下是否存在pre_transform.pt
文件:- 如果存在,意味着之前进行过数据变换,接着需要加载该文件,以获取之前所用的数据变换的方法,并检查它与当前
pre_transform
参数指定的方法是否相同,- 如果不相同则会报出一个警告,“The pre_transform argument differs from the one used in ……”。
- 如果存在,意味着之前进行过数据变换,接着需要加载该文件,以获取之前所用的数据变换的方法,并检查它与当前
- 其次,检查之前的样本过滤的方法:检查
self.processed_dir
目录下是否存在pre_filter.pt
文件:- 如果存在,则加载该文件并获取之前所用的样本过滤的方法,并检查它与当前
pre_filter
参数指定的方法是否相同,- 如果不相同则会报出一个警告,“The pre_filter argument differs from the one used in ……”。
- 如果存在,则加载该文件并获取之前所用的样本过滤的方法,并检查它与当前
- 接着,检查是否存在处理好的数据:检查
self.processed_dir
目录下是否存在self.processed_file_names
属性方法返回的所有文件,如有文件不存在,则需要执行以下的操作:- 调用
process()
方法,进行数据处理。 - 如果
pre_transform
参数不为None
,则调用pre_transform()
函数进行数据处理。 - 如果
pre_filter
参数不为None
,则进行样本过滤(此例子中不需要进行样本过滤,pre_filter
参数为None
)。 - 保存处理好的数据到文件,文件存储在processed_paths()属性方法返回的文件路径。如果将数据保存到多个文件中,则返回的路径有多个。
processed_paths()
属性方法是在基类(DataSet)中定义的,它对self.processed_dir
文件夹与processed_file_names()
属性方法的返回每一个文件名做拼接,然后返回。
- 最后保存新的
pre_transform.pt
文件和pre_filter.pt
文件,它们分别存储当前使用的数据处理方法和样本过滤方法。
- 调用
- 首先,检查之前对数据做变换的方法:检查
- 保证有预处理的文件后,在
self.data, self.slices = torch.load(self.processed_paths[0])
时从预处理文件路径中加载预处理后的数据。 - 在执行
data = dataset[0]
时才调用选择的transform
函数。
实验
现在让我们用twitter数据集来自定义一个数据类试试看吧。数据下载链接
数据示例:
index,o_place,d_place,year,month,day,cnt,o_lat,o_lon,d_lat,d_lon
196,37119,37025,2020,1,1,33,35.227,-80.827,35.41,-80.612
239,39123,39123,2020,1,1,7,41.529,-83.114,41.529,-83.114
462,13277,13277,2020,1,1,12,31.452,-83.508,31.462,-83.514
538,48201,48347,2020,1,1,4,29.795,-95.374,31.614,-94.649
import os.path as osp |
参考文献
[python @property的介绍与使用](https://zhuanlan.zhihu.com/p/64487092)