BERT Series — (6) Analyzing the BERT Source Code

Introduction

The previous post showed how to pre-train a BERT model from scratch with HuggingFace transformers; the AutoModelForMaskedLM class used there can be instantiated as any of the existing masked-language-model classes in the transformers library. This post walks through the transformers source code that implements the BERT model, so that we can design our own models.

Recall the steps for pre-training BERT from scratch; each corresponds to a module in the source code:

  1. Tokenize the corpus with a tokenizer — BertTokenizer;

  2. Configure the model — BertModel (which itself contains several sub-modules, as shown in the figure below);

  3. Write the preprocessing code for the training tasks: mask tokens in each sentence and build positive/negative sentence-pair samples for BERT's two pre-training objectives — the data-preprocessing classes;

  4. Feed the data into the model for training — the data-loading classes such as Dataset, DataLoader and DataCollator.

This post focuses on the code behind the first two steps.
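
For orientation, here is a minimal, hedged sketch of instantiating the two pieces analysed in this post. The checkpoint name is only an example; when pre-training from scratch the model would be built from a fresh config rather than loaded from a checkpoint:

import torch
from transformers import BertTokenizer, BertConfig, BertModel

# Load a tokenizer from a checkpoint (or a local directory containing vocab.txt).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Build a randomly initialised BertModel from a config, as when pre-training from scratch.
config = BertConfig(vocab_size=tokenizer.vocab_size)
model = BertModel(config)

print(type(tokenizer).__name__, type(model).__name__)  # BertTokenizer BertModel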

Tokenizing the corpus — BertTokenizer

The core of the BertTokenizer code lives in models/bert/tokenization_bert.py.

Before reading the source, let's first clarify what BertTokenizer needs to do, and then look at how the source implements it and what it takes into account.

Requirements

The purpose of BertTokenizer is to tokenize the corpus. From this we can derive several questions, which map to the features the class must provide:

  1. The input is a corpus (many sentences) or a single sentence: load all the training data, process the training data, and process a single sentence (including input and output);
  2. Sentences must be split: the tokenization itself;
    • At what granularity? Do Chinese and English need different strategies?: characters (char), subwords (between char and word, e.g. English stems) and words (word);
    • How are subwords joined back into words?: subword-to-word reassembly;
    • The split results must be turned into indices or ids before they can be fed to the model: build a vocabulary from the training set, look up a token by id and an id by token;
    • What about tokens that never appeared in the training set?: a mapping for unseen tokens (when looking up ids), the [UNK] token;
  3. Some special tokens: "[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]";
  4. Saving and loading: saving and loading the BertTokenizer model;
  5. BERT-specific inputs: adding special tokens such as [CLS], building the training pairs for the NSP task, and the masks for the special tokens.

Source code

Now let's see how the actual implementation meets these requirements:

"""Tokenization classes for Bert."""
import collections
import os
import unicodedata
from typing import List, Optional, Tuple

from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...utils import logging

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# Vocabulary files of the pretrained models
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
#……
}
}
# Maximum position-embedding sizes of the pretrained models
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"bert-base-uncased": 512,
#……
}
# Init configurations of the pretrained models
PRETRAINED_INIT_CONFIGURATION = {
"bert-base-uncased": {"do_lower_case": True},
#……
}

# Load the vocabulary
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
for index, token in enumerate(tokens):
token = token.rstrip("\n")
vocab[token] = index
return vocab

# Split on whitespace
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens


class BertTokenizer(PreTrainedTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

def __init__(
self,
vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token="[UNK]", #特殊字符
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
strip_accents=None,
**kwargs
):
super().__init__(
do_lower_case=do_lower_case,
do_basic_tokenize=do_basic_tokenize,
never_split=never_split,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
**kwargs,
)

if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(
do_lower_case=do_lower_case,
never_split=never_split,
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)  # further splits words into subwords on top of the basic tokenization

@property
def do_lower_case(self):
return self.basic_tokenizer.do_lower_case

@property
def vocab_size(self):
return len(self.vocab)

def get_vocab(self):
return dict(self.vocab, **self.added_tokens_encoder)

def _tokenize(self, text):
split_tokens = []
if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):

# If the token is part of the never_split set
if token in self.basic_tokenizer.never_split:
split_tokens.append(token)
else:
split_tokens += self.wordpiece_tokenizer.tokenize(token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
# Look up an id from a token
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.vocab.get(token, self.vocab.get(self.unk_token))
# Look up a token from an id
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.ids_to_tokens.get(index, self.unk_token)
# Join subwords back into words
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string
# Add special tokens such as [CLS] and build the inputs for the NSP task
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A BERT sequence has the following format:
- single sequence: ``[CLS] X [SEP]``
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
"""
if token_ids_1 is None:
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
cls = [self.cls_token_id]
sep = [self.sep_token_id]
return cls + token_ids_0 + sep + token_ids_1 + sep
# Mask of the special tokens, for a single sentence ([CLS] X [SEP]) as well as for a sentence pair ([CLS] A [SEP] B [SEP])
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""

if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)

if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1]
# Token type ids (segment mask) for a sentence pair
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
pair mask has the following format:
::
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
sequence(s).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
# Save the vocabulary (tokens)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
index = 0
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
else:
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!"
)
index = token_index
writer.write(token + "\n")
index += 1
return (vocab_file,)


class BasicTokenizer(object):
def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
if never_split is None:
never_split = []
self.do_lower_case = do_lower_case
self.never_split = set(never_split)
self.tokenize_chinese_chars = tokenize_chinese_chars
self.strip_accents = strip_accents

def tokenize(self, text, never_split=None):
"""
Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
WordPieceTokenizer.
Args:
**never_split**: (`optional`) list of str
Kept for backward compatibility purposes. Now implemented directly at the base class level (see
:func:`PreTrainedTokenizer.tokenize`) List of token not to split.
"""
# union() returns a new set by concatenating the two sets.
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
text = self._clean_text(text)

# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
if self.tokenize_chinese_chars:
text = self._tokenize_chinese_chars(text)  # handle Chinese characters
orig_tokens = whitespace_tokenize(text)  # split on whitespace and return tokens
split_tokens = []
for token in orig_tokens:
if token not in never_split:
if self.do_lower_case:
token = token.lower()
if self.strip_accents is not False:
token = self._run_strip_accents(token)
elif self.strip_accents:
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token, never_split))

output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
# Strip accents from text
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
# Split on punctuation
def _run_split_on_punc(self, text, never_split=None):
"""Splits punctuation on a piece of text."""
if never_split is not None and text in never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1

return ["".join(x) for x in output]
# Add whitespace around every CJK character
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
# Decide from the code point whether a character is a CJK character
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True

return False
# Remove invalid characters
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xFFFD or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)

# Splits words into subwords on top of the basic tokenization
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""

def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word

def tokenize(self, text):
"""
Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
tokenization using the given vocabulary.
For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`.
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""

output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)  # over-long words are mapped to the unknown token
continue

is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:  # greedily find the longest substring that exists in the vocab
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end

if is_bad:
output_tokens.append(self.unk_token)  # words that cannot be segmented into in-vocab subwords become the unknown token
else:
output_tokens.extend(sub_tokens)
return output_tokens

Analysis

BertTokenizer is a tokenizer built on top of BasicTokenizer and WordpieceTokenizer:

  • BasicTokenizer splits sentences on punctuation, whitespace and so on, optionally lower-cases the text, and cleans up invalid characters. In particular:
    • Chinese characters are split per character through a preprocessing step that adds spaces around them;
    • never_split can be used to protect certain tokens from being split;
    • this step is optional (enabled by default): do_basic_tokenize=True.
  • WordpieceTokenizer further decomposes words into subwords.
    • A word is split into subwords by greedily matching the longest substring that exists in the vocab;
    • a subword sits between a char and a word: it largely preserves the meaning of the word while avoiding the vocabulary explosion and OOV (Out-Of-Vocabulary) problems caused by English plurals, tenses and so on; splitting off stems and affixes shrinks the vocabulary and makes training easier;
    • for example, the word tokenizer is split into "token" and "##izer", where the leading "##" marks a piece that attaches to the previous one.

BertTokenizer provides the following commonly used methods (a usage sketch follows this list):

  • from_pretrained: initialize a tokenizer from a directory containing the vocabulary file (vocab.txt);
  • tokenize: split text (a word or a sentence) into a list of subwords;
  • convert_tokens_to_ids: convert a list of subwords into the list of their vocabulary indices;
  • convert_ids_to_tokens: the inverse of the above;
  • convert_tokens_to_string: join a subword list back into a word or sentence using the "##" markers;
  • encode: for a single sentence, tokenize it, add special tokens to form the structure "[CLS], x, [SEP]" and convert it to vocabulary indices; for two sentences (only the first two are used if more are given), form "[CLS], x1, [SEP], x2, [SEP]" and convert to indices;
  • decode: turn the output of encode back into a full sentence.
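
A minimal usage sketch of these methods; the outputs in the comments are illustrative and depend on the checkpoint ("bert-base-uncased" is only an example):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

tokens = tokenizer.tokenize("A tokenizer example")   # e.g. ['a', 'token', '##izer', 'example']
ids = tokenizer.convert_tokens_to_ids(tokens)        # subword ids
back = tokenizer.convert_tokens_to_string(tokens)    # 'a tokenizer example'

# encode adds the special tokens; two sentences produce [CLS] A [SEP] B [SEP]
pair_ids = tokenizer.encode("How are you?", "I am fine.")
print(tokenizer.decode(pair_ids))                    # e.g. '[CLS] how are you? [SEP] i am fine. [SEP]'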

BertModel

The code related to the BERT model lives mainly in /models/bert/modeling_bert.py, which contains the basic model structure as well as the fine-tuning models built on top of it. In this post we focus on the model body, BertModel.

Requirements

Let's use the figure below to analyse the functional requirements of the BERT model.

Overall it needs the following parts:

  1. Input:
    • token embedding: look up each token's embedding from its input id;
    • segment embedding, used for the NSP task;
    • position embedding;
  2. Encoder:
    • the multi-head attention module;
    • residual connections + layer norm;
    • fully connected layers;
  3. Output:
    • a pooling layer, used by the training tasks.
  4. The overall model, plus saving and loading.

Since the code for each part is long, we analyse the source block by block.

Let's start with the overall model.

BertModel

Source code

@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
)
class BertModel(BertPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
input to the forward pass.
"""

def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
# input embeddings
self.embeddings = BertEmbeddings(config)
# encoder
self.encoder = BertEncoder(config)
# output; the pooling layer is optional (add_pooling_layer)
self.pooler = BertPooler(config) if add_pooling_layer else None

self.init_weights()

def get_input_embeddings(self):
return self.embeddings.word_embeddings

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value

def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length = input_shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

device = input_ids.device if input_ids is not None else inputs_embeds.device

# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# compute the embeddings
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
# run the encoder
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,  # note: the mask passed to the encoder is extended_attention_mask, not the raw attention_mask
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# last_hidden_state (the output of the last layer)
sequence_output = encoder_outputs[0]
# the first token [CLS] is used as the representation of the whole sentence (under self-attention every token already aggregates information from the whole sentence)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]  # encoder_outputs[1:] holds the remaining encoder outputs, e.g. past_key_values, hidden_states

# wrap the outputs in a dict-like ModelOutput with named fields
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)

Analysis

From the constructor we can see that BertModel is divided into the three modules highlighted by the red boxes in the figure below (BertEmbeddings, BertEncoder, BertPooler):

The forward pass of BertModel uses the following arguments (a minimal call example follows this list):

  • input_ids: the list of subword indices produced by the tokenizer;
  • attention_mask: used during self-attention to distinguish real subwords from padding; padding positions are filled with 0;
  • token_type_ids: marks which sentence a subword belongs to (first sentence / second sentence / padding);
  • position_ids: the position index of each token within the sentence;
  • head_mask: disables the attention computation of selected heads in selected layers;
  • inputs_embeds: if provided, input_ids is not needed; the embedding lookup is skipped and these embeddings enter the encoder directly;
  • encoder_hidden_states: only takes effect when BertModel is configured as a decoder, in which case cross-attention over these states is performed in addition to self-attention;
  • encoder_attention_mask: likewise, marks the padding of the encoder-side input during cross-attention;
  • past_key_values: this appears to pass in precomputed key/value states so that attention during decoding does not have to recompute them, reducing its cost;
  • use_cache: if set, the key/value states above are returned so they can be reused, speeding up decoding;
  • output_attentions: whether to return the attention maps of every intermediate layer;
  • output_hidden_states: whether to return the outputs of every intermediate layer;
  • return_dict: whether to return the outputs as key/value pairs (a ModelOutput, which can also be used as a tuple); defaults to true.
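
Putting the tokenizer and the model together, a minimal (hypothetical) forward call might look like this; the checkpoint name is only an example:

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

# The tokenizer produces input_ids, token_type_ids and attention_mask in one call.
inputs = tokenizer("How are you?", "I am fine.", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
print(outputs.pooler_output.shape)      # (batch_size, hidden_size)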

Note that head_mask here only invalidates part of the attention scores by multiplying them with this matrix; it is different from the attention-head pruning (_prune_heads) discussed below.

BertModel also provides the following methods:

  • get_input_embeddings: returns the word_embeddings part of the embedding module, i.e. the token embeddings;
  • set_input_embeddings: assigns new word_embeddings to the embedding module;
  • _prune_heads: prunes attention heads; the input is a dict of {layer_num: list of heads to prune in this layer}, so the chosen heads of the chosen layers can be pruned by cutting the weight matrices directly.

Pruning is a non-trivial operation: the Wq/Wk/Wv weights of the heads that are kept, together with the corresponding part of the output projection after concatenation, have to be copied into new, smaller weight matrices (disable grad before copying), and the pruned heads must be recorded so that later indices stay correct. See the prune_heads method of BertAttention, annotated below.
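
For illustration, a small hypothetical usage sketch (the layer and head indices are arbitrary):

from transformers import BertConfig, BertModel

model = BertModel(BertConfig())

# Prune heads 0 and 2 of layer 0, and head 5 of layer 3.
model._prune_heads({0: [0, 2], 3: [5]})

print(model.encoder.layer[0].attention.self.num_attention_heads)  # 12 -> 10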

Also note that the attention_mask used inside the encoder is extended_attention_mask, not the attention_mask passed to BertModel's forward.

Next we analyse the three sub-modules one by one.

Input — BertEmbeddings

The input embedding is obtained by summing three parts:

Source code

class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""

def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
# layer normalization
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")  # position-embedding type
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
persistent=False,
)

def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]

seq_length = input_shape[1]

if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
# issue #5664
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)  # absolute (learned) position embeddings indexed 0, 1, 2, ...
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings

Analysis

  1. word_embeddings: the embeddings of the subwords described above.
  2. token_type_embeddings: indicates which sentence the current token belongs to, helping to distinguish the two sentences of a pair and the padding.
  3. position_embeddings: the position embedding of each token in the sentence, used to encode word order. Unlike the design in the Transformer paper, these are learned rather than fixed embeddings computed from sinusoidal functions. This is generally considered less extensible (hard to transfer directly to longer sentences).

The three embeddings are summed without weights and passed through LayerNorm + dropout; the output has shape (batch_size, sequence_length, hidden_size). [For why LayerNorm is chosen, see: Transformer相关——normalization机制]
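
For contrast, here is a small sketch of the fixed sinusoidal position encoding from the original Transformer paper. BERT does not use this; it is shown only to illustrate the difference with the learned position_embeddings above:

import math
import torch

def sinusoidal_position_encoding(max_len: int, hidden_size: int) -> torch.Tensor:
    # pe[pos, 2i]   = sin(pos / 10000^(2i / hidden_size))
    # pe[pos, 2i+1] = cos(pos / 10000^(2i / hidden_size))
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, hidden_size, 2, dtype=torch.float)
                         * (-math.log(10000.0) / hidden_size))
    pe = torch.zeros(max_len, hidden_size)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

print(sinusoidal_position_encoding(512, 768).shape)  # torch.Size([512, 768])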

BertEncoder

To make this easier to follow, here is the hierarchy of BertEncoder first; we then break it down and analyse the parts one by one.

BertEncoder

Source code
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
# a stack of BertLayer modules
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

next_decoder_cache = () if use_cache else None
# the embeddings pass through each BertLayer in turn
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

if getattr(self.config, "gradient_checkpointing", False) and self.training:

if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)

return custom_forward
# gradient checkpointing lowers GPU memory usage during training: fewer nodes of the computation graph are stored, and the values that were not stored are recomputed when gradients are needed; see the paper "Training Deep Nets with Sublinear Memory Cost".
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,  # the head_mask of the i-th layer
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)

hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
Analysis

BertEncoder is a stack of BertLayer modules.

Notably, the forward pass of BertEncoder can use gradient checkpointing to reduce GPU memory usage during training.

Gradient checkpointing saves memory by storing fewer nodes of the computation graph; the values that were not stored are recomputed when the gradients are computed. See the paper "Training Deep Nets with Sublinear Memory Cost"; the process is illustrated in the figure below:

In BertEncoder, gradient checkpointing is implemented with torch.utils.checkpoint.checkpoint; see also: 网络训练高效内存管理——torch.utils.checkpoint的使用.
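
A minimal sketch of torch.utils.checkpoint.checkpoint outside of BERT, assuming a toy two-layer module: the activations inside the checkpointed segment are recomputed during the backward pass instead of being stored.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer1 = nn.Linear(128, 128)
layer2 = nn.Linear(128, 128)

x = torch.randn(4, 128, requires_grad=True)

# Without checkpointing, layer2(relu(layer1(x))) would keep all intermediate activations.
# With checkpointing, only the segment's inputs are kept and the segment is re-run in backward.
h = checkpoint(lambda inp: layer2(torch.relu(layer1(inp))), x)

h.sum().backward()
print(x.grad.shape)  # torch.Size([4, 128])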

BertLayer

Source code
class BertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BertAttention(config)  # attention block
self.is_decoder = config.is_decoder  # whether this layer is configured as a decoder rather than an encoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
self.crossattention = BertAttention(config)
self.intermediate = BertIntermediate(config)  # feed-forward + activation
self.output = BertOutput(config)  # output layer

def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]  # the context_layer, i.e. the product of the attention matrix and the value matrix

# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:]  # add self attentions (attention_probs) if we output attention weights

cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
assert hasattr(
self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"

# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights

# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
# apply_chunking_to_forward splits the input tensors into smaller chunks and then applies forward_fn; its argument list is forward_fn, chunk_size, chunk_dim, input_tensors
# here it runs self.feed_forward_chunk: feed-forward + activation + output layer
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs

# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)

return outputs
# feed-forward + activation, then the output layer
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
Analysis

BertLayer is composed of BertAttention (the attention mechanism), BertIntermediate (feed-forward + activation) and BertOutput (the output layer).

Note the use of apply_chunking_to_forward here, a memory-saving technique: when config.chunk_size_feed_forward > 0, the input tensor is split along the sequence dimension (seq_len_dim = 1) into chunks of chunk_size tokens, feed_forward_chunk is applied to each chunk, and the results are concatenated again. The argument list of apply_chunking_to_forward is forward_fn, chunk_size, chunk_dim, input_tensors.
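
A simplified sketch of what apply_chunking_to_forward does (ignoring the multiple-input-tensor case), assuming chunking along the sequence dimension:

import torch

def apply_chunking_sketch(forward_fn, chunk_size, chunk_dim, input_tensor):
    # chunk_size == 0 means no chunking: run the function on the whole tensor.
    if chunk_size == 0:
        return forward_fn(input_tensor)
    # Split along chunk_dim, run the feed-forward on each chunk, then concatenate.
    chunks = input_tensor.split(chunk_size, dim=chunk_dim)
    return torch.cat([forward_fn(c) for c in chunks], dim=chunk_dim)

hidden = torch.randn(2, 8, 16)                 # (batch, seq, hidden)
ff = torch.nn.Linear(16, 16)
out = apply_chunking_sketch(ff, chunk_size=4, chunk_dim=1, input_tensor=hidden)
print(out.shape)                                # torch.Size([2, 8, 16])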

BertAttention

Source code
class BertAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = BertSelfAttention(config)  # self-attention
self.output = BertSelfOutput(config)  # output layer
self.pruned_heads = set()

def prune_heads(self, heads):
if len(heads) == 0:
return
# find the heads to prune and the indices of the dimensions to keep
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# prune the weight matrices directly
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)

def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
# outputs: (context_layer, attention_probs)
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
Analysis

find_pruneable_heads_and_indices locates the heads to be pruned and the indices of the dimensions to keep: see the find_pruneable_heads_and_indices source.

prune_linear_layer is where the weights are actually cut: according to index, it keeps the non-pruned dimensions of the \(W^q/W^k/W^v\) weight matrices (together with their biases) and copies them into new matrices. For details see the prune_linear_layer source; a simplified sketch follows.
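
A simplified sketch of the idea behind prune_linear_layer (index selection followed by a copy into a smaller nn.Linear); this illustrates the mechanism and is not the library implementation:

import torch
import torch.nn as nn

def prune_linear_layer_sketch(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
    # Keep only the rows (dim=0) or columns (dim=1) listed in `index`.
    weight = layer.weight.index_select(dim, index).clone().detach()
    if layer.bias is not None:
        bias = layer.bias.clone().detach() if dim == 1 else layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None)
    # Disable grad before copying into the new (smaller) parameters, then re-enable it.
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(weight)
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(bias)
        new_layer.bias.requires_grad = True
    return new_layer

# Toy example: keep the output dimensions belonging to the surviving heads.
query = nn.Linear(768, 768)
kept = torch.arange(0, 640)   # e.g. 10 of 12 heads x 64 dims per head
print(prune_linear_layer_sketch(query, kept).weight.shape)  # torch.Size([640, 768])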

BertSelfAttention

This is the class that actually performs the attention computation; for background on the attention mechanism in the encoder, see the earlier post: Transformer相关——(3)Attention机制.

Source code
class BertSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
f"heads ({config.num_attention_heads})"
)

self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
# linear projections
# key_layer / value_layer / query_layer have shape (batch_size, num_attention_heads, sequence_length, attention_head_size)
# initially hidden_size equals all_head_size; after head pruning (prune_heads) all_head_size becomes smaller
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

self.is_decoder = config.is_decoder

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)

# this branch is the cross-attention part; plain BERT only uses encoder self-attention, so it can be skipped
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
# transpose_for_scores reshapes hidden_size into the per-head output shape and swaps the middle two dimensions for the matrix multiplication
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

query_layer = self.transpose_for_scores(mixed_query_layer)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
# dot-product similarity: QK^T
# attention_scores has shape (batch_size, num_attention_heads, sequence_length, sequence_length),
# i.e. one attention map per head
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
# Depending on position_embedding_type there are three behaviours:
# - absolute: the default, nothing extra to do here;
# - relative_key: compute relative position scores from the query and the relative position embedding and add them to the attention scores;
# - relative_key_query: compute relative position scores from both the query and the key.
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)  # torch.einsum implements Einstein summation notation
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
# scale the dot products: QK^T / sqrt(d_k)
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
# note: this attention_mask is additive — 0 where no masking is needed (values unchanged), a very negative number where masking is needed (after the softmax those positions go to ~0)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

# head_mask appears again here; if it is not set it defaults to all ones and has no effect
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
# softmax(QK^T / sqrt(d_k)) V
# shape: (batch_size, num_attention_heads, sequence_length, attention_head_size)
context_layer = torch.matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)

context_layer = context_layer.view(*new_context_layer_shape)
# context_layer now has shape (batch_size, sequence_length, hidden_size)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
Analysis

The core is the attention mechanism itself; a few details are worth summarising:

  1. hidden_size vs. all_head_size: initially hidden_size and all_head_size are equal; after head pruning (prune_heads) all_head_size becomes smaller.

  2. hidden_size must be an integer multiple of num_attention_heads: see the earlier post for the reason [search for "TIPS"].

  3. transpose_for_scores reshapes hidden_size into the per-head output shape and swaps the middle two dimensions so the matrices can be multiplied;

  4. For torch.einsum and Einstein summation notation, see: einsum满足你一切需要:深度学习中的爱因斯坦求和约定

  5. Depending on position_embedding_type there are three behaviours:

    • absolute: the default, no extra processing;
    • relative_key: relative position scores are computed from the query and the relative positional_embedding and added to the attention scores;
    • relative_key_query: relative position scores are computed from both the query and the key.
  6. attention_scores = attention_scores + attention_mask: this attention_mask is additive — 0 where no masking is needed (values unchanged), a very negative number where masking is needed (after the softmax those positions go to ~0). As emphasised earlier, the mask used inside the encoder is extended_attention_mask, not the attention_mask passed to BertModel's forward. extended_attention_mask is produced by the get_extended_attention_mask function; in the decoder case it additionally masks the tokens that come after the current position (a simplified sketch follows this list).

  7. Also note that get_extended_attention_mask is called as self.get_extended_attention_mask:

    extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

    In other words it is a method of BertModel, yet we did not see its implementation in the source above.

    This is because BertModel is a subclass of BertPreTrainedModel, which is a subclass of PreTrainedModel, which in turn is a subclass of ModuleUtilsMixin; the function is implemented in ModuleUtilsMixin.

  8. The code also applies dropout to attention_probs, following the original Transformer implementation, even though the source itself notes that this is a bit unusual.
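
A simplified sketch of what ModuleUtilsMixin.get_extended_attention_mask does in the encoder case (a 2-D padding mask broadcast to all heads and turned into an additive mask); the decoder case additionally builds a causal, lower-triangular mask:

import torch

def extended_attention_mask_sketch(attention_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # attention_mask: (batch_size, seq_length), 1 for real tokens, 0 for padding.
    # Broadcastable shape for all heads: (batch_size, 1, 1, seq_length).
    extended = attention_mask[:, None, None, :].to(dtype)
    # 0 where attention is allowed, a very negative value where it is masked,
    # so adding it to the attention scores pushes masked positions to ~0 after softmax.
    return (1.0 - extended) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0, 0]])
print(extended_attention_mask_sketch(mask))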

BertSelfOutput

A linear layer + dropout + LayerNorm over the residual connection.

Source code
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

BertIntermediate

A linear layer followed by an activation function chosen via ACT2FN.

Source code
class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]  # ACT2FN maps an activation name (e.g. "gelu") to the activation function
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
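
BERT's config.hidden_act is "gelu", and ACT2FN maps that string to the GELU activation. For reference, a sketch of GELU's commonly used tanh approximation (the exact form uses the Gaussian CDF):

import math
import torch

def gelu_tanh_approx(x: torch.Tensor) -> torch.Tensor:
    # GELU(x) ≈ 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

print(gelu_tanh_approx(torch.tensor([-1.0, 0.0, 1.0])))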

BertOutput

Almost the same structure as BertSelfOutput above (linear + dropout + LayerNorm over the residual connection), except that the linear layer maps intermediate_size back down to hidden_size.

Source code
class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

That concludes BertEncoder; next is the output part.

BertPooler

Source code

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

Analysis

It simply takes the first token of the sentence, i.e. the vector corresponding to [CLS], passes it through a linear layer and a tanh activation, and returns the result. (This part is optional, controlled by add_pooling_layer, since there are many different pooling strategies.)
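
For context, a hedged sketch of how pooler_output is typically consumed by a downstream sentence-level head (this mirrors what fine-tuning heads for sequence classification do, but is written here only as an illustration with a toy classifier):

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

config = BertConfig()
model = BertModel(config)                      # add_pooling_layer=True by default
classifier = nn.Linear(config.hidden_size, 2)  # toy 2-class sentence classifier

input_ids = torch.randint(0, config.vocab_size, (1, 16))
outputs = model(input_ids=input_ids)

logits = classifier(outputs.pooler_output)     # (batch_size, num_labels)
print(logits.shape)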

References

3.1-如何实现一个BERT.md

NLP预训练模型2 -- BERT详解和源码分析