BERT实战——(5)生成任务-机器翻译
BERT实战——(5)生成任务-机器翻译
引言
之前的分别介绍了使用 🤗 Transformers代码库中的模型开展one-class任务(文本分类、多选问答问题)、class for each token任务(序列标注)以及copy from input任务(抽取式问答)。
这一篇以及下一篇将介绍如何使用 🤗 Transformers代码库中的模型来解决general sequence任务(关于什么是生成序列任务,回看之前的博客,定位词:general sequence)。这一篇为解决生成任务中的机器翻译问题。
任务介绍
翻译任务,把一种语言信息转变成另一种语言信息。是典型的seq2seq任务,输入为一个序列,输出为不固定长度(由机器自行学习生成的序列应该多长)的序列。
比如输入一句中文,翻译为英文:
输入:我爱中国。 |
主要分为以下几个部分:
- 数据加载;
- 数据预处理;
- 微调预训练模型:使用transformer中的
Seq2SeqTrainer
接口对预训练模型进行微调(注意这里是Seq2SeqTrainer
接口,之前的任务都是调用Trainer
接口)。
前期准备
安装以下库:
pip install datasets transformers sacrebleu sentencepiece |
数据加载
数据集介绍
我们使用WMT dataset数据集。这是翻译任务最常用的数据集之一。其中包括English/Romanian双语翻译。
加载数据
该数据的加载方式在transformers库中进行了封装,我们可以通过以下语句进行数据加载:
from datasets import load_dataset |
给定一个数据切分的key(train、validation或者test)和下标即可查看数据。
raw_datasets["train"][0] |
下面的函数将从数据集里随机选择几个例子进行展示:
import datasets |
show_random_elements(raw_datasets["train"]) |
translation | |
---|---|
0 | {'en': 'The Bulgarian gymnastics team won the gold medal at the traditional Grand Prix series competition in Thiais, France, which wrapped up on Sunday (March 30th).', 'ro': 'Echipa bulgară de gimnastică a câştigat medalia de aur la tradiţionala competiţie Grand Prix din Thiais, Franţa, care s-a încheiat duminică (30 martie).'} |
1 | {'en': 'Being on that committee, however, you will know that this was a very hot topic in negotiations between Norway and some Member States.', 'ro': 'Totuşi, făcând parte din această comisie, ştiţi că acesta a fost un subiect foarte aprins în negocierile dintre Norvegia şi unele state membre.'} |
2 | {'en': 'The overwhelming vote shows just this.', 'ro': 'Ceea ce demonstrează şi votul favorabil.'} |
3 | {'en': '[Photo illustration by Catherine Gurgenidze for Southeast European Times]', 'ro': '[Ilustraţii foto de Catherine Gurgenidze pentru Southeast European Times]'} |
4 | {'en': '(HU) Mr President, today the specific text of the agreement between the Hungarian Government and the European Commission has been formulated.', 'ro': '(HU) Domnule președinte, textul concret al acordului dintre guvernul ungar și Comisia Europeană a fost formulat astăzi.'} |
数据预处理
在将数据喂入模型之前,我们需要对数据进行预处理。
仍然是两个数据预处理的基本流程:
- 分词;
- 转化成对应任务输入模型的格式;
Tokenizer
用于上面两步数据预处理工作:Tokenizer
首先对输入进行tokenize,然后将tokens转化为预模型中需要对应的token ID,再转化为模型需要的输入格式。
初始化Tokenizer
之前的博客已经介绍了一些Tokenizer的内容,并做了Tokenizer分词的示例,这里不再重复。use_fast=True
指定使用fast版本的tokenizer。我们使用已经训练好的Helsinki-NLP/opus-mt-en-ro
checkpoint来做翻译任务。
from transformers import AutoTokenizer |
以使用的mBART模型为例,需要正确设置source语言和target语言。如果翻译的是其他双语语料,请查看这里进行配置:
if "mbart" in model_checkpoint: |
转化成对应任务输入模型的格式
模型的输入为待翻译的句子。
注意:为了给模型准备好翻译的targets,使用as_target_tokenizer
来为targets设置tokenizer:
with tokenizer.as_target_tokenizer(): |
如果使用的是T5预训练模型的checkpoints,需要对特殊的前缀进行检查。T5使用特殊的前缀来告诉模型具体要做的任务("translate English to Romanian: "
),具体前缀例子如下:
if model_checkpoint in ["t5-small", "t5-base", "t5-larg", "t5-3b", "t5-11b"]: |
现在我们可以把上面的内容放在一起组成预处理函数preprocess_function
。对样本进行预处理的时候,使用truncation=True
参数来确保超长文本被截断。默认情况下,对与比较短的句子会自动padding。
max_input_length = 128 |
以上的预处理函数可以处理一个样本,也可以处理多个样本exapmles。如果是处理多个样本,则返回的是多个样本被预处理之后的结果list。
接下来使用map函数对数据集datasets里面三个样本集合的所有样本进行预处理,将预处理函数preprocess_function
应用到(map)所有样本上。参数batched=True
可以批量对文本进行编码。这是为了充分利用前面加载fast_tokenizer的优势,它将使用多线程并发地处理批中的文本。
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True) |
微调预训练模型
数据已经准备好了,我们需要下载并加载预训练模型,然后微调预训练模型。
加载预训练模型
做seq2seq任务,那么需要一个能解决这个任务的模型类。我们使用AutoModelForSeq2SeqLM
这个类。
和之前几篇博客提到的加载方式相同不再赘述。
from transformers import AutoModelForSeq2SeqLM, |
设定训练参数
为了能够得到一个Seq2SeqTrainer
训练工具,我们还需要训练的设定/参数 Seq2SeqTrainingArguments
。这个训练设定包含了能够定义训练过程的所有属性。
由于数据集比较大,Seq2SeqTrainer
训练时会同时不断保存模型,我们用save_total_limit=3
参数控制至多保存3个模型。
from transformers import Seq2SeqTrainingArguments |
数据收集器data collator
接下来需要告诉Trainer
如何从预处理的输入数据中构造batch。我们使用数据收集器DataCollatorForSeq2Seq
,将经预处理的输入分batch再次处理后喂给模型。
from transformers import DataCollatorForSeq2Seq |
定义评估方法
我们使用'bleu'
指标,利用metric.compute
计算该指标对模型进行评估。
metric.compute
对比predictions和labels,从而计算得分。predictions和labels都需要是一个list。具体格式见下面的例子:
fake_preds = ["hello there", "general kenobi"] |
将模型预测送入评估之前,还需要写postprocess_text
函数做一些数据后处理:
import numpy as np |
开始训练
将数据/模型/参数传入Trainer
即可:
from transformers import Seq2SeqTrainer |
调用train
方法开始训练:
trainer.train() |