PyG-InMemoryDataset加载数据

引言

我们这次重点关注第一部分,并提供2020年1月1日的twitter数据作为实验数据。首先分析继承InMemoryDataset类的Planetoid类,然后说明该类实例化流程,并详细说明如何继承该类设计自己的数据加载类,最后利用twitter数据写一个自己的数据加载类。

源码分析

我们以PyG中加载Planetoid数据集的源码为例进行分析,说明如何通过继承InMemoryDataset类来自定义一个数据可全部存储到内存的数据集类。

源码如下:

from typing import Optional, Callable, List
import os.path as osp
import torch
from torch_geometric.data import InMemoryDataset, download_url
from torch_geometric.io import read_planetoid_data

class Planetoid(InMemoryDataset):

url = 'https://github.com/kimiyoung/planetoid/raw/master/data'
def __init__(self, root: str, name: str, split: str = "public",
num_train_per_class: int = 20, num_val: int = 500,
num_test: int = 1000, transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None):
self.name = name

super().__init__(root, transform, pre_transform) #等同于super(Planetoid,self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])#InMemoryDataset继承于Dataset类,processed_paths是Dataset类的属性,有self.processed_dir拼接得到

# 将数据集划分为训练、验证和测试集
self.split = split
assert self.split in ['public', 'full', 'random']

if split == 'full':
data = self.get(0)
data.train_mask.fill_(True)
data.train_mask[data.val_mask | data.test_mask] = False
self.data, self.slices = self.collate([data])

elif split == 'random':
data = self.get(0)
data.train_mask.fill_(False)
for c in range(self.num_classes):
idx = (data.y == c).nonzero(as_tuple=False).view(-1)
idx = idx[torch.randperm(idx.size(0))[:num_train_per_class]]
data.train_mask[idx] = True

remaining = (~data.train_mask).nonzero(as_tuple=False).view(-1)
remaining = remaining[torch.randperm(remaining.size(0))]

data.val_mask.fill_(False)
data.val_mask[remaining[:num_val]] = True

data.test_mask.fill_(False)
data.test_mask[remaining[num_val:num_val + num_test]] = True

self.data, self.slices = self.collate([data])

@property
def raw_dir(self) -> str:
return osp.join(self.root, self.name, 'raw')

@property
def processed_dir(self) -> str:
return osp.join(self.root, self.name, 'processed')

@property
def raw_file_names(self) -> List[str]:
names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
return [f'ind.{self.name.lower()}.{name}' for name in names]

@property
def processed_file_names(self) -> str:
return 'data.pt'

def download(self):
for name in self.raw_file_names:
download_url('{}/{}'.format(self.url, name), self.raw_dir)

def process(self):
data = read_planetoid_data(self.raw_dir, self.name)#调用函数读取数据,包装成Data
data = data if self.pre_transform is None else self.pre_transform(data)
torch.save(self.collate([data]), self.processed_paths[0])

def __repr__(self) -> str:
return f'{self.name}()'

可以发现总共实现了8个函数,其中前5个函数为核心:

  1. __init()__
  2. raw_file_names(self)
  3. processed_file_names(self)
  4. download(self)
  5. process(self)
  6. raw_dir(self)
  7. processed_dir(self)
  8. __repr__(self)

接下来分别介绍各个函数的作用和功能。

*__init()__

用于初始化类,包括两个部分:初始化父类(InMemoryDataset类)、划分数据集:

  1. 几个核心参数用于初始化InMemoryDataset类:

    • root:字符串类型,存储数据集的文件夹的路径下。该文件夹下有两个文件夹:
      • 一个文件夹为记录在raw_dir,它用于存储未处理的文件,从网络上下载的数据集原始文件会被存放到这里;
      • 另一个文件夹记录在processed_dir处理后的数据被保存到这里,以后从此文件夹下加载文件即可获得Data对象。
      • 注:raw_dirprocessed_dir是属性方法,我们可以自定义要使用的文件夹。
    • transform:函数类型,一个数据转换函数,它接收一个Data对象并返回一个转换后的Data对象。此函数在每一次数据获取过程中都会被执行。获取数据的函数首先使用此函数对Data对象做转换,然后才返回数据。此函数应该用于数据增广(Data Augmentation)。该参数默认值为None,表示不对数据做转换。
    • pre_transform:函数类型,一个数据转换函数,它接收一个Data对象并返回一个转换后的Data对象。此函数在Data对象被保存到文件前调用。因此它应该用于只执行一次的数据预处理。该参数默认值为None,表示不做数据预处理。
    • pre_filter:函数类型,一个检查数据是否要保留的函数,它接收一个Data对象,返回此Data对象是否应该被包含在最终的数据集中。此函数也在Data对象被保存到文件前调用。该参数默认值为None,表示不做数据检查,保留所有的数据。
  2. 划分数据集

    将数据集别用掩码(mask)划分为训练、验证和测试集三个部分。

*raw_file_names()

属性方法,返回一个数据集原始文件的文件名列表,数据集原始文件应该能在raw_dir文件夹中找到,否则调用download()函数下载文件到raw_dir文件夹。

*processed_file_names()

属性方法,返回一个存储处理过的数据的文件的文件名列表,存储处理过的数据的文件应该能在processed_dir文件夹中找到,否则调用process()函数对样本做处理,然后保存处理过的数据到processed_dir文件夹下的文件里。

*download()

根据定义的url属性下载数据集原始文件raw_dir文件夹。

*process()

调用读取数据函数,将数据包装成Data,然后处理数据,保存处理好的数据到processed_dir文件夹下的文件。

raw_dir()

属性方法,原始数据存储的文件夹路径,我们可以自定义要使用的文件夹。

processed_dir()

属性方法,处理后数据存储的文件夹路径,我们可以自定义要使用的文件夹。

__repr__()

这个函数对应repr(object)这个功能。意思是当需要显示一个对象在屏幕上时,将这个对象的属性或者是方法整理成一个可以打印输出的格式。python __repr__的作用

也就是说:

print(Planetoid)

会打印出:

Planetoid()

其他符号和设计

:->Optional[Callable]

参数+冒号+数据类型,如def __init__(self, root: str, name: str, split: str = "public",num_train_per_class: int = 20, num_val:int = 500,num_test: int = 1000, transform: Optional[Callable] = None,pre_transform: Optional[Callable] = None),以及函数+"->"+数据类型,如def raw_dir(self) -> str:

参考:Python函数参数中的冒号与箭头Python中typing模块与类型注解的使用方法

函数参数中的冒号是参数的类型建议符,告诉程序员希望传入的实参的类型。

函数后面跟着的箭头是函数返回值的类型建议符,用来说明该函数返回的值是什么类型。

类型建议符并非强制规定和检查,也就是说即使传入的实际参数与建议参数不符,也不会报错。

Optional[Callable]Optional表示这个参数可以为空或已经声明的类型,Callable表示是一个可调用类型的参数(函数等)。

@property

@property修饰的方法为属性方法,如:

@property
def raw_file_names(self) -> List[str]:
pass

参考[python @property的介绍与使用](https://zhuanlan.zhihu.com/p/64487092):

python的@property是python的一种装饰器,是用来修饰方法的。可以使用@property装饰器来创建只读属性@property装饰器会将方法转换为相同名称的只读属性,可以与所定义的属性配合使用,这样可以防止属性被修改

Planetoid类实例化流程

执行以下语句时,该类的实例化流程如下。

dataset = Planetoid(root='dataset/PlanetoidPubMed',transform=NormalizeFeatures())
data = dataset[0].to(device) #这一步才执行transform的函数
  1. 首先,检查数据原始文件是否已下载:
    • 检查self.raw_dir目录下是否存在raw_file_names()属性方法返回的每个文件,
    • 如有文件不存在,则调用download()方法执行原始文件下载。
  2. 其次,检查数据是否经过处理:
    • 首先,检查之前对数据做变换的方法:检查self.processed_dir目录下是否存在pre_transform.pt文件:
      • 如果存在,意味着之前进行过数据变换,接着需要加载该文件,以获取之前所用的数据变换的方法,并检查它与当前pre_transform参数指定的方法是否相同,
        • 如果不相同则会报出一个警告,“The pre_transform argument differs from the one used in ……”。
    • 其次,检查之前的样本过滤的方法:检查self.processed_dir目录下是否存在pre_filter.pt文件:
      • 如果存在,则加载该文件并获取之前所用的样本过滤的方法,并检查它与当前pre_filter参数指定的方法是否相同,
        • 如果不相同则会报出一个警告,“The pre_filter argument differs from the one used in ……”。
    • 接着,检查是否存在处理好的数据:检查self.processed_dir目录下是否存在self.processed_file_names属性方法返回的所有文件,如有文件不存在,则需要执行以下的操作:
      • 调用process()方法,进行数据处理。
      • 如果pre_transform参数不为None,则调用pre_transform()函数进行数据处理。
      • 如果pre_filter参数不为None,则进行样本过滤(此例子中不需要进行样本过滤,pre_filter参数为None)。
      • 保存处理好的数据到文件,文件存储在processed_paths()属性方法返回的文件路径。如果将数据保存到多个文件中,则返回的路径有多个。
        • processed_paths()属性方法是在基类(DataSet)中定义的,它对self.processed_dir文件夹与processed_file_names()属性方法的返回每一个文件名做拼接,然后返回。
      • 最后保存新的pre_transform.pt文件和pre_filter.pt文件,它们分别存储当前使用的数据处理方法和样本过滤方法。
  3. 保证有预处理的文件后,在self.data, self.slices = torch.load(self.processed_paths[0])时从预处理文件路径中加载预处理后的数据。
  4. 在执行data = dataset[0]时才调用选择的transform函数。

实验

现在让我们用twitter数据集来自定义一个数据类试试看吧。数据下载链接

数据示例:

index,o_place,d_place,year,month,day,cnt,o_lat,o_lon,d_lat,d_lon 196,37119,37025,2020,1,1,33,35.227,-80.827,35.41,-80.612 239,39123,39123,2020,1,1,7,41.529,-83.114,41.529,-83.114 462,13277,13277,2020,1,1,12,31.452,-83.508,31.462,-83.514 538,48201,48347,2020,1,1,4,29.795,-95.374,31.614,-94.649

import os.path as osp
import torch

from torch_geometric.data import (InMemoryDataset, download_url)
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.data import Data
#读取twitter数据函数
def read_twitter_data(raw_dir,raw_file_names):
data=[]
for raw_file in raw_file_names:
raw_file=osp.join(raw_dir,raw_file)
source_list = []
target_list = []
node_set = set()
with open(raw_file,'r',encoding='utf-8') as f:
i=0
for line in f.readlines():
if i==0:
i += 1
continue
index,o_place,d_place,year,month,day,cnt,o_lat,o_lon,d_lat,d_lon = line.split(',')
source_list.append(int(o_place))
target_list.append(int(d_place))
node_set.add(o_place)
node_set.add(d_place)
edge_index=torch.stack([torch.tensor(source_list,dtype=torch.long),torch.tensor(target_list,dtype=torch.long)],dim=0)
#我们这里没有节点特征也没有节点label,设置x为单位矩阵,y=None
x=torch.eye(len(node_set))
data=Data(x=x,edge_index=edge_index,y=None) #包装成Data类
return data

class TwitterDataset(InMemoryDataset):
url = 'https://github.com/ifwind/GNN_datawhale/blob/main/2020-1-1.txt'

def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
super(TwitterDataset, self).__init__(root=root, transform=transform, pre_transform=pre_transform,
pre_filter=pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])#InMemoryDataset继承于Dataset类,processed_paths是Dataset类的属性,有self.processed_dir拼接得到

@property
def raw_dir(self):
return osp.join(self.root, 'raw')

@property
def processed_dir(self):
return osp.join(self.root, 'processed')

@property
def raw_file_names(self):
names = ['2020-1-1']
return ['{}.txt'.format(name) for name in names]

@property
def processed_file_names(self):
return ['data.pt']

def download(self):
for name in self.raw_file_names:
download_url('{}/{}'.format(self.url, name), self.raw_dir)

def process(self):
data = read_twitter_data(self.raw_dir,self.raw_file_names) #调用twitter数据读取函数
data = data if self.pre_transform is None else self.pre_transform(data)
torch.save(self.collate([data]), self.processed_paths[0])

def __repr__(self):
return '{}()'.format(self.name)

dataset = TwitterDataset(root='dataset/Twitter', transform=NormalizeFeatures())

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = dataset[0].to(device)
print(data)
#Data(edge_index=[2, 18642], x=[2727, 2727])

参考文献

加载Planetoid数据集的源码

6-1-数据完整存于内存的数据集类.md

python __repr__的作用

[python @property的介绍与使用](https://zhuanlan.zhihu.com/p/64487092)

Python函数参数中的冒号与箭头

Python中typing模块与类型注解的使用方法