BERT实战——(6)生成任务-摘要生成
BERT实战——(6)生成任务-摘要生成
引言
这一篇将介绍如何使用 🤗 Transformers代码库中的模型来解决生成任务中的摘要生成问题。
任务介绍
摘要生成,用一些精炼的话(摘要)来概括整片文章的大意,用户通过读文摘就可以了解到原文要表达。
主要分为以下几个部分:
- 数据加载;
- 数据预处理;
- 微调预训练模型:使用transformer中的
Seq2SeqTrainer
接口对预训练模型进行微调(注意这里是Seq2SeqTrainer
接口,之前的任务都是调用Trainer
接口)。
前期准备
安装以下库:
pip install datasets transformers rouge-score nltk |
数据加载
数据集介绍
我们使用XSum dataset数据集,其中包含了多篇BBC的文章和一句对应的摘要。
加载数据
该数据的加载方式在transformers库中进行了封装,我们可以通过以下语句进行数据加载:
from datasets import load_dataset |
给定一个数据切分的key(train、validation或者test)和下标即可查看数据。
raw_datasets["train"][0] |
下面的函数将从数据集里随机选择几个例子进行展示:
import datasets |
show_random_elements(raw_datasets["train"][1:]) |
document | id | summary | |
---|---|---|---|
0 | Media playback is unsupported on your device18 December 2014 Last updated at 10:28 GMThas successfully tackled poverty over the last four decades by drawing on its rich natural resources.to the World Bank, some 49% of Malaysians in 1970 were extremely poor, and that figure has been reduced to 1% today. However, the government's next challenge is to help the lower income group to move up to the middle class, the bank says.Zahau, the World Bank's Southeast Asia director, spoke to the BBC's Jennifer Pak. | 30530533 | In Malaysia the "aspirational" low-income part of the population is helping to drive economic growth through consumption, according to the World Bank. |
数据预处理
在将数据喂入模型之前,我们需要对数据进行预处理。
仍然是两个数据预处理的基本流程:
- 分词;
- 转化成对应任务输入模型的格式;
Tokenizer
用于上面两步数据预处理工作:Tokenizer
首先对输入进行tokenize,然后将tokens转化为预模型中需要对应的token ID,再转化为模型需要的输入格式。
初始化Tokenizer
之前的博客已经介绍了一些Tokenizer的内容,并做了Tokenizer分词的示例,这里不再重复。use_fast=True
指定使用fast版本的tokenizer。
我们使用t5-small
模型checkpoint来做该任务。
from transformers import AutoTokenizer |
转化成对应任务输入模型的格式
模型的输入为待翻译的句子。
注意:为了给模型准备好翻译的targets,使用as_target_tokenizer
来为targets设置tokenizer:
with tokenizer.as_target_tokenizer(): |
如果使用的是T5预训练模型的checkpoints(比如我们这里用的t5-small),需要对特殊的前缀进行检查。T5使用特殊的前缀("summarize: "
)来告诉模型具体要做的任务,具体前缀例子如下:
if model_checkpoint in ["t5-small", "t5-base", "t5-larg", "t5-3b", "t5-11b"]: |
现在我们可以把上面的内容放在一起组成预处理函数preprocess_function
。对样本进行预处理的时候,使用truncation=True
参数来确保超长文本被截断。默认情况下,对与比较短的句子会自动padding。max_input_length
控制了输入文本的长度,max_target_length
控制了摘要的最长长度。
max_input_length = 1024 |
以上的预处理函数可以处理一个样本,也可以处理多个样本exapmles。如果是处理多个样本,则返回的是多个样本被预处理之后的结果list。
接下来使用map函数对数据集datasets里面三个样本集合的所有样本进行预处理,将预处理函数preprocess_function
应用到(map)所有样本上。参数batched=True
可以批量对文本进行编码。这是为了充分利用前面加载fast_tokenizer的优势,它将使用多线程并发地处理批中的文本。
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True) |
微调预训练模型
数据已经准备好了,我们需要下载并加载预训练模型,然后微调预训练模型。
加载预训练模型
做seq2seq任务,那么需要一个能解决这个任务的模型类。我们使用AutoModelForSeq2SeqLM
这个类。
和之前几篇博客提到的加载方式相同不再赘述。
from transformers import AutoModelForSeq2SeqLM, |
设定训练参数
为了能够得到一个Seq2SeqTrainer
训练工具,我们还需要训练的设定/参数 Seq2SeqTrainingArguments
。这个训练设定包含了能够定义训练过程的所有属性。
由于数据集比较大,Seq2SeqTrainer
训练时会同时不断保存模型,我们用save_total_limit=3
参数控制至多保存3个模型。
from transformers import Seq2SeqTrainingArguments |
数据收集器data collator
接下来需要告诉Trainer
如何从预处理的输入数据中构造batch。我们使用数据收集器DataCollatorForSeq2Seq
,将经预处理的输入分batch再次处理后喂给模型。
from transformers import DataCollatorForSeq2Seq |
定义评估方法
我们使用'rouge'
指标,利用metric.compute
计算该指标对模型进行评估。
使用metric.compute
对比predictions和labels,从而计算得分。predictions和labels都需要是一个list。具体格式见下面的例子:
fake_preds = ["hello there", "general kenobi"] |
将模型预测送入评估之前,还需要写postprocess_text
函数做一些数据后处理。
nltk是一个自然语言处理的python工具包,我们这里用到了其中一个按句子分割的函数nltk.sent_tokenize()
。
import nltk |
开始训练
将数据/模型/参数传入Trainer
即可:
from transformers import Seq2SeqTrainer |
调用train
方法开始训练:
trainer.train() |