ALBERT passes all tests

This commit is contained in:
Lysandre
2019-10-31 14:04:10 +00:00
committed by Lysandre Debut
parent 870320a24e
commit c14a22272f
3 changed files with 5 additions and 10 deletions

View File

@@ -202,17 +202,14 @@ class AlbertLayerGroup(nn.Module):
layer_attentions = ()
for albert_layer in self.albert_layers:
if self.output_hidden_states:
layer_hidden_states = layer_hidden_states + (hidden_states,)
layer_output = albert_layer(hidden_states, attention_mask, head_mask)
hidden_states = layer_output[0]
if self.output_attentions:
layer_attentions = layer_attentions + (layer_output[1],)
if self.output_hidden_states:
layer_hidden_states = layer_hidden_states + (hidden_states,)
if self.output_hidden_states:
layer_hidden_states = layer_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
@@ -247,7 +244,7 @@ class AlbertTransformer(nn.Module):
hidden_states = layer_group_output[0]
if self.output_attentions:
all_attentions = all_attentions + layer_group_output[1]
all_attentions = all_attentions + layer_group_output[-1]
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)