rename the attributes in the Bert Layer

Since the preloading of weights relies on the name of the class's
attributes changing the namespace breaks loading pretrained weights on
Bert and all related models. I reverted `self_attention` to `attention`
and us `crossattention` for the decoder instead.
This commit is contained in:
Rémi Louf
2019-10-10 12:51:14 +02:00
parent 51261167b4
commit d7092d592c

View File

@@ -321,25 +321,24 @@ class BertOutput(nn.Module):
class BertLayer(nn.Module): class BertLayer(nn.Module):
def __init__(self, config): def __init__(self, config):
super(BertLayer, self).__init__() super(BertLayer, self).__init__()
self.self_attention = BertAttention(config) self.attention = BertAttention(config)
if getattr(config, "is_decoder", False): if getattr(config, "is_decoder", False):
self.attention = BertAttention(config) self.crossattention = BertAttention(config)
self.intermediate = BertIntermediate(config) self.intermediate = BertIntermediate(config)
self.output = BertOutput(config) self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None): def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask) attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
self_attention_output = self_attention_outputs[0] attention_output = attention_outputs[0]
attention_outputs = self_attention_outputs
if encoder_hidden_state: if encoder_hidden_state:
try: try:
attention_outputs = self.attention(self_attention_output, attention_mask, head_mask, encoder_hidden_state) crossattention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
except AttributeError as ae: except AttributeError as ae:
raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer") raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer")
attention_output = attention_outputs[0] crossattention_output = crossattention_outputs[0]
intermediate_output = self.intermediate(attention_output) intermediate_output = self.intermediate(crossattention_output)
layer_output = self.output(intermediate_output, attention_output) layer_output = self.output(intermediate_output, attention_output)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs return outputs
@@ -633,7 +632,7 @@ class BertModel(BertPreTrainedModel):
See base class PreTrainedModel See base class PreTrainedModel
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].self_attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
if attention_mask is None: if attention_mask is None:
@@ -737,7 +736,7 @@ class BertDecoderModel(BertPreTrainedModel):
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.decoder.layer[layer].attention.prune_heads(heads) self.decoder.layer[layer].attention.prune_heads(heads)
self.decoder.layer[layer].self_attention.prune_heads(heads) self.decoder.layer[layer].crossattention.prune_heads(heads)
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
if attention_mask is None: if attention_mask is None: