BERT Series (8): Code Analysis of BERT-based Models

Introduction

The previous post discussed how to use BERT for downstream tasks and how to fine-tune it. BertModel outputs an embedding for every token of every sentence; by attaching different task heads after the BERT model we can build different models.

HuggingFace's transformers library wraps the simplest possible API for each task, which helps us get started quickly.

The following task-specific models are implemented:

  1. BertForPreTraining
  2. BertForSequenceClassification
  3. BertForMultipleChoice
  4. BertForTokenClassification
  5. BertForQuestionAnswering

If we want to add extra modules such as an LSTM or a CRF layer to improve the model, we can imitate how these ready-made APIs are written. At the end of this post I will summarize how to extend the BERT model to complete NLP tasks.

Let's now walk through the source code of each of these APIs.

The BertPreTrainedModel base class

All BERT-based models in HuggingFace's transformers library derive from the abstract base class BertPreTrainedModel, which in turn derives from the larger base class PreTrainedModel. Here we focus on what BertPreTrainedModel itself does:

BertPreTrainedModel handles weight initialization, and it also carries a few class variables inherited from PreTrainedModel that identify the model family or control how checkpoints are loaded.

class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
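
A tiny sketch (assuming a default BertConfig) of what this initialization means in practice: every submodule of a model built on BertPreTrainedModel gets its weights drawn from a normal distribution with standard deviation config.initializer_range.

from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertModel

config = BertConfig()
model = BertModel(config)  # _init_weights is applied to every submodule
print(config.initializer_range)                              # 0.02 by default
print(model.embeddings.word_embeddings.weight.std().item())  # roughly 0.02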

Next, let's start the analysis with the pre-training models.

BertForPreTraining and related classes

The original BERT pre-training consists of two training objectives:

  • Masked Language Model (MLM): randomly replace some tokens in the sentence with [MASK], feed the sentence through BERT to get an embedding for every token, and use the embedding at each [MASK] position to predict the original token. This objective trains the model to understand a word from its context.
  • Next Sentence Prediction (NSP): feed a sentence pair A and B into BERT and use the [CLS] embedding to predict whether B is the sentence that follows A. This objective trains the model to understand the relationship between sentences.

Requirements analysis

Given these two objectives, we clearly need the following:

  1. initialize a BertModel;

  2. implement the code for these two training objectives (what if we only want to train one of them? -> models that perform a single objective);
  3. the training loss function;
  4. model saving.

How HuggingFace implements the BertForPreTraining family

HuggingFace's transformers library provides a pre-training model that covers both objectives, as well as models that pre-train on only one of them:

  1. BertForPreTraining: pre-trains with both the MLM and the NSP objective;
  2. BertForMaskedLM: pre-trains with the MLM objective only;
  3. BertLMHeadModel: differs from the previous one in that it runs as a decoder (causal language model);
  4. BertForNextSentencePrediction: pre-trains with the NSP objective only.

The way these classes are composed is analyzed below.

BertForPreTraining

First comes BertForPreTraining, the pre-training model that covers both objectives.

Usage example

from transformers import BertTokenizer, BertForPreTraining
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)
prediction_logits = outputs.prediction_logits
seq_relationship_logits = outputs.seq_relationship_logits
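
As a hedged sketch continuing the example above, labels for both objectives can be passed to obtain the combined pre-training loss; the label tensors below are purely illustrative.

# MLM labels use -100 for every position that should be ignored by the loss.
mlm_labels = torch.full_like(inputs["input_ids"], -100)
mlm_labels[0, 4] = inputs["input_ids"][0, 4]  # pretend position 4 was masked
nsp_label = torch.LongTensor([0])             # 0 = sentence B follows sentence A
outputs = model(**inputs, labels=mlm_labels, next_sentence_label=nsp_label)
print(outputs.loss)  # MLM loss + NSP loss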

Source code

@add_start_docstrings(
"""
Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
sentence prediction (classification)` head.
""",
BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.bert = BertModel(config) # add_pooling_layer=True by default, so the pooled [CLS] output is available for the NSP task
self.cls = BertPreTrainingHeads(config) # the module responsible for the MLM and NSP heads

self.init_weights()

def get_output_embeddings(self):
return self.cls.predictions.decoder

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
(see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
- 0 indicates sequence B is a continuation of sequence A,
- 1 indicates sequence B is a random sequence.
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.
Returns:
Example::
>>> from transformers import BertTokenizer, BertForPreTraining
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.prediction_logits
>>> seq_relationship_logits = outputs.seq_relationship_logits
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) # MLM and NSP predictions

total_loss = None
if labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) # MLM loss
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) # NSP loss
total_loss = masked_lm_loss + next_sentence_loss

if not return_dict:
output = (prediction_scores, seq_relationship_score) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
# wrap the outputs
return BertForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

Source code analysis

The BertForPreTraining class consists of three main parts: a BertModel, a BertPreTrainingHeads module, and the loss computation.

  1. BertModel: analyzed in detail in the previous post (add_pooling_layer=True is the default, so BertModel includes the BertPooler layer, i.e. the output corresponding to [CLS] is extracted for the NSP task);
  2. BertPreTrainingHeads: the prediction module responsible for the MLM and NSP tasks (analyzed further below);
  3. loss computation: the MLM loss plus the NSP loss. Both tasks are essentially classification, so CrossEntropyLoss is used for each, and the two losses are simply summed to obtain the final loss.
    • labels: shape [batch_size, seq_length], the labels for the MLM task. Note that tokens that were not masked get the label -100, while masked tokens get the id of the original token, which is the opposite of how the input is constructed (see the short sketch after this list).
      • For example, if the original sentence is "I want to [MASK] an apple" with the word "eat" masked out, the label is [-100, -100, -100, <id of "eat">, -100, -100];
      • Why -100 rather than some other value? Because torch.nn.CrossEntropyLoss uses ignore_index=-100 by default, so positions whose label is -100 do not contribute to the loss.
    • next_sentence_label: a binary 0/1 label, where 0 means the two sentences are consecutive and 1 means they were randomly paired.
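
A minimal sketch of how ignore_index=-100 behaves; the tensors below are made up for illustration only.

import torch
from torch.nn import CrossEntropyLoss

loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
vocab_size = 5
logits = torch.randn(4, vocab_size)           # scores for 4 token positions
labels = torch.tensor([-100, -100, 3, -100])  # only position 2 was masked
loss = loss_fct(logits, labels)               # equals the loss on position 2 alone
print(loss)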

BertPreTrainingHeads: the MLM head and the NSP head

Source code

class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)  # MLM head
        self.seq_relationship = nn.Linear(config.hidden_size, 2)  # NSP head

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score
Source code analysis

BertPreTrainingHeads contains a BertLMPredictionHead for the MLM task and a single linear layer nn.Linear for the NSP task.

Note: the nn.Linear used here for the NSP task is functionally identical to the BertOnlyNSPHead class shown later (which is essentially the same single nn.Linear, just wrapped in its own module).

BertLMPredictionHead: the MLM head

Source code

class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)  # transform applied before the MLM prediction

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # classify each position over the vocabulary

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

Source code analysis

This head turns the output at each position (in particular the [MASK] positions) into a classification over the vocabulary, where every vocabulary entry is a class. Note that:

  • the class initializes a separate all-zero vector as the bias of the prediction weights;
  • its output has shape [batch_size, seq_length, vocab_size], i.e. a score over the vocabulary for every token of every sentence (no softmax is applied here, because CrossEntropyLoss already includes the softmax computation);
  • BertPredictionHeadTransform performs the preceding transformation:

BertPredictionHeadTransform: the transform used by the MLM head
class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]  # activation function
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

Source code analysis

It applies a linear transformation, an activation function, and layer normalization to the input embeddings.

BertForMaskedLM

BertForMaskedLM performs only the MLM task. In fact, later models such as RoBERTa, ALBERT, and SpanBERT dropped the NSP task altogether (removing the NSP loss leaves downstream performance on par with, or slightly better than, the original BERT).

Usage example

from transformers import BertTokenizer, BertForMaskedLM
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)
prediction_logits = outputs.logits  # MaskedLMOutput only exposes `logits`; there is no NSP output here
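
A small sketch of how the MLM head is typically used to fill in a [MASK] token (continuing the example above); the sentence is just an example.

masked = tokenizer("Hello, my dog is [MASK].", return_tensors="pt")
with torch.no_grad():
    logits = model(**masked).logits

# Locate the [MASK] position and take the highest-scoring vocabulary id.
mask_pos = (masked["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
predicted_id = logits[0, mask_pos].argmax(dim=-1)
print(tokenizer.decode(predicted_id))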

Source code

@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):

_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

def __init__(self, config):
super().__init__(config)

if config.is_decoder:
logger.warning(
"If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
)

self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config) # MLM head, a wrapper around BertLMPredictionHead

self.init_weights()

def get_output_embeddings(self):
return self.cls.predictions.decoder

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)

masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
effective_batch_size = input_shape[0]

# add a dummy token
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
dummy_token = torch.full(
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
)
input_ids = torch.cat([input_ids, dummy_token], dim=1)

return {"input_ids": input_ids, "attention_mask": attention_mask}

BertOnlyMLMHead: a wrapper around BertLMPredictionHead

class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores

Source code analysis

The same three steps again:

  1. initialize a BertModel;
  2. perform the MLM task;
  3. loss computation: naturally, only the MLM loss.

BertLMHeadModel: next-token prediction (the decoder version)

Unlike BertForMaskedLM above, BertLMHeadModel is the decoder version of the model: it can only attend to the left context, not the right, and the task becomes next-token prediction.

Usage example

from transformers import BertTokenizer, BertLMHeadModel, BertConfig
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = BertConfig.from_pretrained("bert-base-uncased")
config.is_decoder = True
model = BertLMHeadModel.from_pretrained('bert-base-uncased', config=config)
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)
prediction_logits = outputs.logits

Source code

@add_start_docstrings(
"""Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):

_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

def __init__(self, config):
super().__init__(config)

if not config.is_decoder:
logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")

self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config) # wraps the BertLMPredictionHead class

self.init_weights()

def get_output_embeddings(self):
return self.cls.predictions.decoder

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
Returns:
Example::
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
>>> config = BertConfig.from_pretrained("bert-base-cased")
>>> config.is_decoder = True
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)

lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output

return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]

return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

def _reorder_cache(self, past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past

Source code analysis

The same three steps again:

  1. initialize a BertModel;
  2. perform the LM task: note that here it is next-token prediction;
  3. loss computation: the loss of the next-token predictions, where the scores and the labels are shifted by one position relative to each other (a small sketch follows).
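
A minimal sketch of the label shifting used for the causal LM loss above; the shapes and tensors are illustrative.

import torch
from torch.nn import CrossEntropyLoss

vocab_size, batch, seq_len = 10, 2, 6
prediction_scores = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# The score at position t predicts the token at position t + 1,
# so drop the last score and the first label before computing the loss.
shifted_scores = prediction_scores[:, :-1, :].contiguous()
shifted_labels = labels[:, 1:].contiguous()
loss = CrossEntropyLoss()(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))
print(loss)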

BertForNextSentencePrediction

It performs only the NSP task.

Usage example

from transformers import BertTokenizer, BertForNextSentencePrediction
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
outputs = model(**encoding, labels=torch.LongTensor([1]))
logits = outputs.logits
assert logits[0, 0] < logits[0, 1] # next sentence was random
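
As a small follow-up sketch, the two logits can be turned into probabilities with a softmax (continuing the example above):

probs = torch.softmax(logits, dim=-1)
print(f"P(B follows A) = {probs[0, 0].item():.3f}, P(B is random) = {probs[0, 1].item():.3f}")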

Source code

@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING,
)
class BertForNextSentencePrediction(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.bert = BertModel(config)
self.cls = BertOnlyNSPHead(config) # wraps the linear layer used for the NSP task

self.init_weights()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
(see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
- 0 indicates sequence B is a continuation of sequence A,
- 1 indicates sequence B is a random sequence.
Returns:
Example::
>>> from transformers import BertTokenizer, BertForNextSentencePrediction
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
>>> logits = outputs.logits
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
"""

if "next_sentence_label" in kwargs:
warnings.warn(
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

pooled_output = outputs[1]

seq_relationship_scores = self.cls(pooled_output)

next_sentence_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

if not return_dict:
output = (seq_relationship_scores,) + outputs[2:]
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

return NextSentencePredictorOutput(
loss=next_sentence_loss,
logits=seq_relationship_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

BertOnlyNSPHead


class BertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score

Source code analysis

The same three steps again:

  1. initialize a BertModel;
  2. perform the NSP task;
  3. loss computation: naturally, only the NSP loss.

Next come the four fine-tuning models implemented in HuggingFace's transformers library. They are essentially all classification tasks, matching the previous post; seq2seq models are not included.

  1. one class-BertForSequenceClassification
  2. one class-BertForMultipleChoice
  3. class for each token-BertForTokenClassification
  4. copy from input-BertForQuestionAnswering

one class-BertForSequenceClassification

Used for sentence-level classification (or regression) tasks, such as the tasks of the GLUE benchmark.

The input is a sentence (or sentence pair); the output is a single classification label.

Usage example

import torch
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.models.bert.modeling_bert import BertForSequenceClassification
tokenizer = BertTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = BertForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc")

classes = ["not paraphrase", "is paraphrase"]

sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"

# The tokenizer will automatically add any model specific separators (i.e. [CLS] and [SEP]) and tokens to the sequence, as well as compute the attention masks.
paraphrase = tokenizer(sequence_0, sequence_2, return_tensors="pt")
not_paraphrase = tokenizer(sequence_0, sequence_1, return_tensors="pt")

paraphrase_classification_logits = model(**paraphrase).logits
not_paraphrase_classification_logits = model(**not_paraphrase).logits

paraphrase_results = torch.softmax(paraphrase_classification_logits, dim=1).tolist()[0]
not_paraphrase_results = torch.softmax(not_paraphrase_classification_logits, dim=1).tolist()[0]

# Should be paraphrase
for i in range(len(classes)):
    print(f"{classes[i]}: {int(round(paraphrase_results[i] * 100))}%")

# Should not be paraphrase
for i in range(len(classes)):
    print(f"{classes[i]}: {int(round(not_paraphrase_results[i] * 100))}%")

Source code

@add_start_docstrings(
"""
Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
BERT_START_DOCSTRING,
)
class BertForSequenceClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config

self.bert = BertModel(config)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

self.init_weights()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

pooled_output = outputs[1]

pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)

loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

Source code analysis

Three steps:

  1. initialize a BertModel (with the pooling layer);
  2. classification: a dropout layer followed by a linear layer that produces the classification logits;
  3. loss computation: the classification loss (see the sketch after this list).
    • A labels tensor must be passed in the forward call:

      • if the model was initialized with num_labels=1, the problem is treated as regression and MSELoss is used;

      • if num_labels > 1 and the labels are integer-typed, it is treated as single-label classification with CrossEntropyLoss; otherwise it falls back to multi-label classification with BCEWithLogitsLoss.
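
A minimal fine-tuning sketch showing how labels produce the classification loss; the checkpoint and the label value are illustrative.

import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
labels = torch.tensor([1])                 # one label per sequence in the batch
outputs = model(**inputs, labels=labels)
print(outputs.loss, outputs.logits.shape)  # scalar loss, logits of shape [1, 2]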

one class-BertForMultipleChoice

Used for multiple-choice tasks such as RocStories/SWAG.

The input is a group of candidate sentences for each example; the output is a single label selecting one of them.

Structurally it is similar to sentence classification, except that the linear layer has output dimension 1: the per-sentence scores of one example are then concatenated to form that example's prediction scores.

In practice, all candidate sentences of every example in a batch are fed in together, so one forward pass processes [batch_size, num_choices] sentences. For the same batch size this therefore needs more GPU memory than sentence classification, which is worth keeping in mind during training.

Usage example

import torch
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.models.bert.modeling_bert import BertForMultipleChoice
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMultipleChoice.from_pretrained("bert-base-uncased")
choices = ["Hello, my dog is cute", "Hello, my cat is pretty"]
input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # shape [1, num_choices, seq_len]
labels = torch.tensor(1).unsqueeze(0)  # the index of the correct choice
outputs = model(input_ids, labels=labels)

Source code

@add_start_docstrings(
"""
Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
BERT_START_DOCSTRING,
)
class BertForMultipleChoice(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.bert = BertModel(config)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, 1)

self.init_weights()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
else None
)

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

pooled_output = outputs[1]

pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)

if not return_dict:
output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return MultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

Source code analysis

Three steps:

  1. initialize a BertModel (with the pooling layer);
  2. classification: a dropout layer followed by a linear layer with output dimension 1, after which the scores are reshaped to [batch_size, num_choices];
  3. loss computation: a classification loss over the choices.
    • The labels passed in the forward call have shape [batch_size], with each value in [0, num_choices - 1] indicating the correct choice.

class for each token-BertForTokenClassification

Used for sequence labeling (token-level classification), e.g. NER tasks.

The input is a single sentence; the output is a class label for every token.

Usage example

from transformers import BertForTokenClassification, BertTokenizer
import torch

model = BertForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

label_list = [
"O", # Outside of a named entity
"B-MISC", # Beginning of a miscellaneous entity right after another miscellaneous entity
"I-MISC", # Miscellaneous entity
"B-PER", # Beginning of a person's name right after another person's name
"I-PER", # Person's name
"B-ORG", # Beginning of an organisation right after another organisation
"I-ORG", # Organisation
"B-LOC", # Beginning of a location right after another location
"I-LOC" # Location
]

sequence = "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very close to the Manhattan Bridge."

# Bit of a hack to get the tokens with the special tokens
tokens = tokenizer.tokenize(tokenizer.decode(tokenizer.encode(sequence)))
inputs = tokenizer.encode(sequence, return_tensors="pt")

outputs = model(inputs).logits
predictions = torch.argmax(outputs, dim=2)

for token, prediction in zip(tokens, predictions[0].numpy()):
    print((token, model.config.id2label[prediction]))

Source code

@add_start_docstrings(
"""
Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
BERT_START_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):

_keys_to_ignore_on_load_unexpected = [r"pooler"] # do not warn about the unused pooler weights when loading

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels

self.bert = BertModel(config, add_pooling_layer=False) # the pooler layer is not needed
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

self.init_weights()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs[0]

sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

Source code analysis

Three steps:

  1. initialize a BertModel: since the output of every token is needed rather than just the pooled one, the BertModel is built without the pooling layer (add_pooling_layer=False). In addition, the class attribute _keys_to_ignore_on_load_unexpected is set to [r"pooler"], so loading a checkpoint that contains the unused pooler weights does not trigger a warning.
  2. classification: a dropout layer followed by a linear layer that produces the per-token logits;
  3. loss computation: the classification loss for every token (see the sketch after this list).
    • The labels passed in the forward call have shape [batch_size, seq_length]; when an attention_mask is given, padded positions are re-labeled with ignore_index so they do not contribute to the loss.
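
A small sketch of the masking trick used above for the token-level loss; all tensors are made up.

import torch
from torch.nn import CrossEntropyLoss

num_labels, batch, seq_len = 3, 1, 5
logits = torch.randn(batch, seq_len, num_labels)
labels = torch.tensor([[0, 1, 2, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # last two positions are padding

loss_fct = CrossEntropyLoss()
active = attention_mask.view(-1) == 1
active_labels = torch.where(active, labels.view(-1), torch.tensor(loss_fct.ignore_index))
loss = loss_fct(logits.view(-1, num_labels), active_labels)
print(loss)  # the padded positions are ignored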

copy from input-BertForQuestionAnswering

Used for extractive question answering, e.g. the SQuAD task.

The input is a sentence pair made of the question and the passage containing the answer (for BERT, a single passage); the output is a start position \(s\) and an end position \(e\) that mark the answer span inside the passage.

Two outputs are therefore needed: a prediction for the start position and one for the end position. Both have the same length as the input sequence, and for each the index with the highest score is taken as the predicted position.

Usage example

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")
model = AutoModelForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")

text = "🤗 Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch."

questions = [
"How many pretrained models are available in 🤗 Transformers?",
"What does 🤗 Transformers provide?",
"🤗 Transformers provides interoperability between which frameworks?",
]

for question in questions:
    inputs = tokenizer(question, text, add_special_tokens=True, return_tensors="pt")
    input_ids = inputs["input_ids"].tolist()[0]
    outputs = model(**inputs)
    answer_start_scores = outputs.start_logits
    answer_end_scores = outputs.end_logits
    answer_start = torch.argmax(
        answer_start_scores
    )  # Get the most likely beginning of answer with the argmax of the score
    answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of answer with the argmax of the score
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
    print(f"Question: {question}")
    print(f"Answer: {answer}")

Source code

@add_start_docstrings(
"""
Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
BERT_START_DOCSTRING,
)
class BertForQuestionAnswering(BertPreTrainedModel):

_keys_to_ignore_on_load_unexpected = [r"pooler"]

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels

self.bert = BertModel(config, add_pooling_layer=False)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

self.init_weights()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs[0]

logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()

total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions = start_positions.clamp(0, ignored_index) # labels outside the sequence length are clamped into the valid range
end_positions = end_positions.clamp(0, ignored_index) # labels outside the sequence length are clamped into the valid range

loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2

if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output

return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

Source code analysis

Three steps:

  1. initialize a BertModel (without the pooling layer);
  2. span prediction: a single linear layer (qa_outputs) maps each token's hidden state to two scores, which are split into start logits and end logits;
  3. loss computation: the loss of the predicted start position and the loss of the predicted end position, each again a classification problem; the two losses are averaged (see the sketch after this list).
    • start_positions and end_positions must be passed in the forward call;
    • labels that fall outside the sequence length are clamped (torch.clamp) into the valid range and then ignored via ignore_index.
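
A minimal sketch of the span loss computed above; the tensors are made up.

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len = 2, 8
start_logits = torch.randn(batch, seq_len)
end_logits = torch.randn(batch, seq_len)
start_positions = torch.tensor([1, 3])
end_positions = torch.tensor([2, 20]).clamp(0, seq_len)  # 20 is out of range and is clamped to seq_len

loss_fct = CrossEntropyLoss(ignore_index=seq_len)  # clamped (out-of-range) labels are ignored
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
print(total_loss)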

Summary: how to build a model on top of BERT

As the source code analyses above show, every model follows the same three steps:

  1. initialize a BertModel, deciding whether the pooler layer is needed; this is controlled by the add_pooling_layer argument, which defaults to True;
  2. define task-specific layers that process the BertModel output, e.g. Dropout and Linear; more complex modules such as an LSTM or a CNN can also be attached, in which case the input/output dimensions between layers need care;
  3. compute the loss, again defined by the task; the forward function usually needs the ground truth passed in.

Adding step 0 (inherit from BertPreTrainedModel) and step 4 (return the required outputs), we arrive at the following template for writing a BERT-based model (a concrete sketch with an LSTM head follows the template):

class MyBertModel(BertPreTrainedModel):
    # whether the pooler weights can be ignored depends on the task
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # add_pooling_layer controls whether BertModel ends with a Pooler layer
        self.bert = BertModel(config, add_pooling_layer=False)

        self.mylayer = ....  # e.g. nn.Linear(config.hidden_size, config.num_labels), an LSTM, a CRF, ...

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        next_layer_input = outputs[0]  # the sequence output; use outputs[1] for the pooled output

        logits1 = self.mylayer(next_layer_input)
        logits2 = ....
        total_loss = None
        loss_fct = CrossEntropyLoss()
        loss1 = loss_fct(logits1, ground_truth1)  # the ground truth is passed into forward (e.g. via label arguments)
        loss2 = loss_fct(logits2, ground_truth2)
        total_loss = (loss1 + loss2) / 2

        if not return_dict:
            output = (logits1, logits2) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return MyModelOutput(
            loss=total_loss,
            logits1=logits1,
            logits2=....,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
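
To make step 2 concrete, here is a hedged sketch that follows this template and puts a BiLSTM on top of BertModel for token classification; the class name and hyperparameters are my own choices, not part of the library.

import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput

class BertLSTMForTokenClassification(BertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config, add_pooling_layer=False)
        # BiLSTM over the token representations; hidden size chosen so that the
        # two directions concatenate back to config.hidden_size.
        self.lstm = nn.LSTM(config.hidden_size, config.hidden_size // 2,
                            batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                labels=None, return_dict=None, **kwargs):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids, return_dict=return_dict)
        sequence_output = outputs[0]                 # [batch, seq_len, hidden]
        lstm_output, _ = self.lstm(sequence_output)  # [batch, seq_len, hidden]
        logits = self.classifier(self.dropout(lstm_output))

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(loss=loss, logits=logits,
                                     hidden_states=outputs.hidden_states,
                                     attentions=outputs.attentions)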

References

Source code: models/bert/modeling_bert.py

3.2-如何应用一个BERT.md