rename the attributes in the Bert Layer
Since the preloading of weights relies on the names of the class's attributes, changing the namespace breaks loading pretrained weights on Bert and all related models. I reverted `self_attention` to `attention` and use `crossattention` for the decoder instead.
This commit is contained in:
@@ -321,25 +321,24 @@ class BertOutput(nn.Module):
|
|||||||
class BertLayer(nn.Module):
|
class BertLayer(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertLayer, self).__init__()
|
super(BertLayer, self).__init__()
|
||||||
self.self_attention = BertAttention(config)
|
|
||||||
if getattr(config, "is_decoder", False):
|
|
||||||
self.attention = BertAttention(config)
|
self.attention = BertAttention(config)
|
||||||
|
if getattr(config, "is_decoder", False):
|
||||||
|
self.crossattention = BertAttention(config)
|
||||||
self.intermediate = BertIntermediate(config)
|
self.intermediate = BertIntermediate(config)
|
||||||
self.output = BertOutput(config)
|
self.output = BertOutput(config)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
|
||||||
self_attention_outputs = self.self_attention(hidden_states, attention_mask, head_mask)
|
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||||
self_attention_output = self_attention_outputs[0]
|
attention_output = attention_outputs[0]
|
||||||
|
|
||||||
attention_outputs = self_attention_outputs
|
|
||||||
if encoder_hidden_state:
|
if encoder_hidden_state:
|
||||||
try:
|
try:
|
||||||
attention_outputs = self.attention(self_attention_output, attention_mask, head_mask, encoder_hidden_state)
|
crossattention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
|
||||||
except AttributeError as ae:
|
except AttributeError as ae:
|
||||||
raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer")
|
raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer")
|
||||||
|
|
||||||
attention_output = attention_outputs[0]
|
crossattention_output = crossattention_outputs[0]
|
||||||
intermediate_output = self.intermediate(attention_output)
|
intermediate_output = self.intermediate(crossattention_output)
|
||||||
layer_output = self.output(intermediate_output, attention_output)
|
layer_output = self.output(intermediate_output, attention_output)
|
||||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||||
return outputs
|
return outputs
|
||||||
@@ -633,7 +632,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
See base class PreTrainedModel
|
See base class PreTrainedModel
|
||||||
"""
|
"""
|
||||||
for layer, heads in heads_to_prune.items():
|
for layer, heads in heads_to_prune.items():
|
||||||
self.encoder.layer[layer].self_attention.prune_heads(heads)
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
@@ -737,7 +736,7 @@ class BertDecoderModel(BertPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
for layer, heads in heads_to_prune.items():
|
for layer, heads in heads_to_prune.items():
|
||||||
self.decoder.layer[layer].attention.prune_heads(heads)
|
self.decoder.layer[layer].attention.prune_heads(heads)
|
||||||
self.decoder.layer[layer].self_attention.prune_heads(heads)
|
self.decoder.layer[layer].crossattention.prune_heads(heads)
|
||||||
|
|
||||||
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user