From d7092d592ca55391b3c07505539b9e4c71bf79de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 10 Oct 2019 12:51:14 +0200 Subject: [PATCH] rename the attributes in the Bert Layer Since the preloading of weights relies on the name of the class's attributes changing the namespace breaks loading pretrained weights on Bert and all related models. I reverted `self_attention` to `attention` and us `crossattention` for the decoder instead. --- transformers/modeling_bert.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index fddf5d52a2..5fcf41a1e1 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -321,25 +321,24 @@ class BertOutput(nn.Module): class BertLayer(nn.Module): def __init__(self, config): super(BertLayer, self).__init__() - self.self_attention = BertAttention(config) + self.attention = BertAttention(config) if getattr(config, "is_decoder", False): - self.attention = BertAttention(config) + self.crossattention = BertAttention(config) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None): - self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask) - self_attention_output = self_attention_outputs[0] + attention_outputs = self.attention(hidden_states, attention_mask, head_mask) + attention_output = attention_outputs[0] - attention_outputs = self_attention_outputs if encoder_hidden_state: try: - attention_outputs = self.attention(self_attention_output, attention_mask, head_mask, encoder_hidden_state) + crossattention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state) except AttributeError as ae: raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer") - attention_output = attention_outputs[0] - intermediate_output = self.intermediate(attention_output) + crossattention_output = crossattention_outputs[0] + intermediate_output = self.intermediate(crossattention_output) layer_output = self.output(intermediate_output, attention_output) outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them return outputs @@ -633,7 +632,7 @@ class BertModel(BertPreTrainedModel): See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].self_attention.prune_heads(heads) + self.encoder.layer[layer].attention.prune_heads(heads) def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): if attention_mask is None: @@ -737,7 +736,7 @@ class BertDecoderModel(BertPreTrainedModel): """ for layer, heads in heads_to_prune.items(): self.decoder.layer[layer].attention.prune_heads(heads) - self.decoder.layer[layer].self_attention.prune_heads(heads) + self.decoder.layer[layer].crossattention.prune_heads(heads) def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): if attention_mask is None: