BERT相关——(6)BERT代码分析
BERT相关——(6)BERT代码分析
引言
上一篇介绍了如何利用HuggingFace的transformers从头开始预训练BERT模型,所使用的AutoModelForMaskedLM
函数可以实例化为transformers library中现有的masked language model中的模型类之一。这一篇将分析transformers中实现BERT模型相关的源码,以便我们可以设计自己的模型。
回顾一下从头开始预训练BERT模型的几个步骤,分别对应了各个模块的源码:
利用Tokenizer对语料分词——BertTokenizer;
重新配置模型——BertModel(其中又包括了各个子模块,如下图所示);
编写满足训练任务的处理代码:每个句子进行掩膜并组成句子对的正负样本集合以完成BERT训练的两个任务——数据预处理相关的类;
数据输入模型进行训练——数据加载入模型相关类如:Dataset、DataLoader、DataCollator。
这一篇主要分析前两步的代码。
Tokenizer对语料分词——BertTokenizer
BertTokenizer代码核心部分在中models/bert/tokenization_bert.py。
我们先来理清一下需要BertTokenizer干什么事情,再来看源码是怎么实现的,考虑了什么。
需求分析
BertTokenizer的目的是对语料分词,据此衍生出以下几个问题,也对应了该类需要满足的几个功能:
- 输入语料库(多个句子组成)或一个句子:加载全部训练数据、对训练数据进行处理、对单个句子进行处理(包括输入、输出);
- 要能对句子进行拆分:分词功能;
- 拆分尺度到哪一步呢?中英文的拆分策略差异?:字(char)、subword(介于 char 和 word之间,比如英文的词根 )和词(word)等;
- subword怎么拼接回word?:subword 拼接回word;
- 拆分结果要生成索引或者是id才能方便输入模型:根据训练集构建字典、根据id查词、根据词查id;
- 训练集中未出现的词怎么处理?:未出现词的映射(根据词查id时)、[UNK]字符
- 一些特殊字符"[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]";
- 模型保存、加载:BertTokenizer模型的保存、加载;
- 与BERT模型输入相关的:添加[CLS]等特殊字符、构建NSP任务的训练集、特殊字符的掩码。
源码
现在让我们来看一下具体的源码实现中是怎么实现上述需求的:
"""Tokenization classes for Bert.""" |
源码分析
BertTokenizer
是基于BasicTokenizer
和WordPieceTokenizer
的分词器:
- BasicTokenizer按标点、空格等分割句子,并处理是否统一小写,以及清理非法字符。其中:
- 对于中文字符,通过预处理(加空格)来按字分割;
- 同时可以通过never_split指定对某些词不进行分割;
- 这一步是可选的(默认执行)
do_basic_tokenize=True
。
- WordPieceTokenizer在词的基础上,进一步将词分解为子词(subword)。
- 找到字符串中存在vocab中的subword将word划分为subword;
- subword 介于 char 和 word 之间,既在一定程度保留了词的含义,又能够照顾到英文中单复数、时态导致的词表爆炸和未登录词的 OOV(Out-Of-Vocabulary)问题,将词根与时态词缀等分割出来,从而减小词表,也降低了训练难度;
- 例如,tokenizer 这个词就可以拆解为“token”和“##izer”两部分,注意后面一个词的“##”表示接在前一个词后面。
BertTokenizer 有以下常用方法:
- from_pretrained:从包含词表文件(vocab.txt)的目录中初始化一个分词器;
- tokenize:将文本(词或者句子)分解为子词列表;
- convert_tokens_to_ids:将子词列表转化为子词对应下标的列表;
- convert_ids_to_tokens :与上一个相反;
- convert_tokens_to_string:将 subword 列表按“##”拼接回词或者句子;
- encode:对于单个句子输入,分解词并加入特殊词形成“[CLS], x, [SEP]”的结构并转换为词表对应下标的列表;对于两个句子输入(多个句子只取前两个),分解词并加入特殊词形成“[CLS], x1, [SEP], x2, [SEP]”的结构并转换为下标列表;
- decode:可以将 encode 方法的输出变为完整句子。
BertModel
和 BERT 模型有关的代码主要写在/models/bert/modeling_bert.py
中,包含 BERT 模型的基本结构和基于它的微调模型等。这篇中我们主要先看Bert模型的本体-BertModel。
需求分析
让我们结合下面这张图来分析一下BERT模型的功能需求。
总的来说需要以下几个部分:
- 输入:
- 根据输入的id获取各个词的嵌入:token embedding;
- 用于NSP任务的segment embedding;
- 生成位置嵌入position embedding;
- encoder:
- 多头注意力机制模块;
- 残差模块+layer norm;
- 全连接层;
- 输出:
- 池化层,用于训练任务。
- 整体的模型、模型保存、加载。
因为每个部分代码都很长,所以分块分析源码。
先从整体模型来看。
BertModel
源码
|
源码分析
在初始化部分可以发现源码将BertModel划分成下图红框的三个模块(BertEmbedding、BertEncoder、BertPooler):
BertModel前向传播中利用到了以下参数:
- input_ids:经过 tokenizer 分词后的 subword 对应的下标列表;
- attention_mask:在 self-attention 过程中,这一块 mask 用于标记 subword 所处句子和 padding 的区别,将 padding 部分填充为 0;
- token_type_ids:标记 subword 当前所处句子(第一句/第二句/ padding);
- position_ids:标记当前词所在句子的位置下标;
- head_mask:用于将某些层的某些注意力计算无效化;
- inputs_embeds:如果提供了,那就不需要input_ids,跨过 embedding lookup 过程直接作为 Embedding 进入 Encoder 计算;
- encoder_hidden_states:这一部分在 BertModel 配置为 decoder 时起作用,将执行 cross-attention 而不是 self-attention;
- encoder_attention_mask:同上,在 cross-attention 中用于标记 encoder 端输入的 padding;
- past_key_values:这个参数貌似是把预先计算好的 K-V 乘积传入,以降低 cross-attention 的开销(因为原本这部分是重复计算);
- use_cache:将保存上一个参数并传回,加速 decoding;
- output_attentions:是否返回中间每层的 attention 输出;
- output_hidden_states:是否返回中间每层的输出;
- return_dict:是否按键值对的形式(ModelOutput 类,也可以当作 tuple 用)返回输出,默认为真。
注意,这里的 head_mask
对部分attention_score
无效化,和下文提到的注意力头剪枝(_prune_heads
)不同,仅仅把attention_score
给乘以这一矩阵。
此外,BertModel 还有以下的方法:
- get_input_embeddings:提取 embedding 中的 word_embeddings 即词向量部分;
- set_input_embeddings:为 embedding 中的 word_embeddings 赋值;
- _prune_heads:提供了将注意力头剪枝的函数,输入为
{layer_num: list of heads to prune in this layer}
的字典,可以将指定层的某些注意力头剪枝,直接对权重矩阵剪枝。
剪枝是一个复杂的操作,需要将保留的注意力头部分的 Wq、Kq、Vq 和拼接后全连接部分的权重拷贝到一个新的较小的权重矩阵(注意先禁止 grad 再拷贝),并实时记录被剪掉的头以防下标出错。具体参考BertAttention部分的prune_heads方法,下文会标注出来。
此外,还需要注意,encoder中的attention_mask是extended_attention_mask而不是BertModel前向传播输入的attention_mask。
接下来我们分块分析上面划分的三个子模块。
输入-BertEmbeddings
输入的嵌入包含三个部分求和得到:
源码
class BertEmbeddings(nn.Module): |
源码分析
- word_embeddings,上文中 subword 对应的嵌入。
- token_type_embeddings,用于表示当前词所在的句子,辅助区别句子与 padding、句子对间的差异
- position_embeddings,句子中每个词的位置嵌入,用于区别词的顺序。和 transformer 论文中的设计不同,这一块是训练出来的,而不是通过 Sinusoidal 函数计算得到的固定嵌入。一般认为这种实现不利于拓展性(难以直接迁移到更长的句子中)。
三个 embedding 不带权重相加,并通过一层 LayerNorm+dropout 后输出,其大小为(batch_size, sequence_length, hidden_size)。【为什么选择LayerNorm可参考:Transformer相关——normalization机制】
BertEncoder
为了方便理解,先把BertEncoder的层次结构图放上来,进一步拆分逐个分析。
BertEncoder
源码
class BertEncoder(nn.Module): |
源码分析
BertEncoder由多个BertLayer组成。
特别地,BertEncoder在前向传播时利用了 gradient checkpointing 技术以降低训练时的显存占用。
gradient checkpointing 即梯度检查点,通过减少保存的计算图节点压缩模型占用空间,但是在计算梯度的时候需要重新计算没有存储的值,参考论文《Training Deep Nets with Sublinear Memory Cost》,过程如下示意图 图:
在 BertEncoder 中,gradient checkpoint 是通过
torch.utils.checkpoint.checkpoint
实现的,可以参考:网络训练高效内存管理——torch.utils.checkpoint的使用
BertLayer
源码
class BertLayer(nn.Module): |
源码分析
BertLayer被拆分为:BertAttention(attention机制)、BertIntermediate(全连接+激活)、BertOutput(输出层)。
注意这里用apply_chunking_to_forward
函数将 input_tensors 分块成更小的输入张量部分,是一个节约显存的技术,(比如:11*768维度的拆分成2个chunk为11*384的张量进行计算),然后执行feed_forward_chunk
函数,apply_chunking_to_forward
函数输入参数列表为:forward_fn,chunk_size,chunk_dim,input_tensors
。
BertAttention
源码
class BertAttention(nn.Module): |
源码分析
find_pruneable_heads_and_indices
定位需要剪掉的 head,以及需要保留的维度下标 index:find_pruneable_heads_and_indices源码
prune_linear_layer
这里是直接剪枝权重,负责将 \(W^k/W^q/W^v\) 权重矩阵(连同 bias)中按照 index 保留没有被剪枝的维度后转移到新的矩阵。具体请查看prune_linear_layer源码。
BertSelfAttention
这个类才真正是做了attention操作,关于Encoder中的attention机制可以参考之前的博客:Transformer相关——(3)Attention机制
源码
class BertSelfAttention(nn.Module): |
源码分析
核心就是attention机制,有几个细节在这里总结一下:
hidden_size
和all_head_size
:一开始hidden_size 和 all_head_size 是一样的,做了剪枝操作(prune head)后all_head_size 变小。hidden_size 必须是 num_attention_heads 的整数倍:原因可查看之前的博客【定位词”
TIPS
“】transpose_for_scores
用来把hidden_size
拆成多个头输出的形状,并且将中间两维转置以进行矩阵相乘;关于
torch.einsum
爱因斯坦求和约定,可参考:einsum满足你一切需要:深度学习中的爱因斯坦求和约定对于不同的positional_embedding_type,有三种操作:
- absolute:默认值,不作处理;
- relative_key:对 key_layer 作处理,将positional_embedding和 key 矩阵相乘作为 key 相关的位置编码;
- relative_key_query:对 key 和 value 都与positional_embedding进行相乘以作为各自的位置编码。
attention_scores = attention_scores + attention_mask
:这里的attention_mask,不需要mask的地方为0(不改变值),需要mask的地方为一个很小的数(相加以后经过softmax趋近于0)。前面特别提到encoder中的attention_mask是extended_attention_mask而不是BertModel前向传播输入的attention_mask。这个extended_attention_mask
由get_extended_attention_mask
函数产生。目的是为了在decoder的时候对当前词后序出现的词进行掩码。又注意到,这个
get_extended_attention_mask
在调用时,用的是self.get_extended_attention_mask
:extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
也就是说它是BertModel内部的函数,但在上面的源码中我们并没有发现该函数的实现。
事实上,这是因为BertModel是BertPreTrainedModel的子类,而BertPreTrainedModel是PreTrainedModel的子类,PreTrainedModel又是ModuleUtilsMixin的子类,该函数就是在 ModuleUtilsMixin中实现的。
代码还为attention_probs做了dropout,这是原始Transformer论文中的实现,虽然源码中也提到了比较unusual。
BertSelfOutput
全连接+LayerNorm+dropout。
源码
class BertSelfOutput(nn.Module): |
BertIntermediate
全连接+ACT2FN激活函数。
源码
class BertIntermediate(nn.Module): |
BertOutput
和上面的BertSelfOuput一毛一样:全连接+LayerNorm+dropout。
源码
class BertSelfOutput(nn.Module): |
到这里BertEncoder终于结束了,接下来是输出部分。
BertPooler
源码
class BertPooler(nn.Module): |
源码分析
只是简单地取出了句子的第一个token,即[CLS]
对应的向量,然后过一个全连接层和一个激活函数后输出:(这一部分是可选的,根据add_pooling_layer
选择,因为pooling有很多不同的操作)。