From 16263f9685eaf459408f33c9790a967012b93fa5 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 7 Nov 2019 17:29:29 +0000 Subject: [PATCH] Headmasking --- transformers/modeling_albert.py | 11 ++++++----- transformers/tests/modeling_albert_test.py | 1 - 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/transformers/modeling_albert.py b/transformers/modeling_albert.py index 89ece4b61e..6682930d89 100644 --- a/transformers/modeling_albert.py +++ b/transformers/modeling_albert.py @@ -224,7 +224,7 @@ class AlbertLayer(nn.Module): self.activation = ACT2FN[config.hidden_act] def forward(self, hidden_states, attention_mask=None, head_mask=None): - attention_output = self.attention(hidden_states, attention_mask) + attention_output = self.attention(hidden_states, attention_mask, head_mask) ffn_output = self.ffn(attention_output[0]) ffn_output = self.activation(ffn_output) ffn_output = self.ffn_output(ffn_output) @@ -245,8 +245,8 @@ class AlbertLayerGroup(nn.Module): layer_hidden_states = () layer_attentions = () - for albert_layer in self.albert_layers: - layer_output = albert_layer(hidden_states, attention_mask, head_mask) + for layer_index, albert_layer in enumerate(self.albert_layers): + layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index]) hidden_states = layer_output[0] if self.output_attentions: @@ -283,7 +283,8 @@ class AlbertTransformer(nn.Module): for layer_idx in range(self.config.num_hidden_layers): group_idx = int(layer_idx / self.config.num_hidden_layers * self.config.num_hidden_groups) - layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask) + layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups) + layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group]) hidden_states = layer_group_output[0] @@ -544,7 +545,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel): def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, masked_lm_labels=None): - outputs = self.albert(input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None) + outputs = self.albert(input_ids, attention_mask, token_type_ids, position_ids, head_mask) sequence_outputs = outputs[0] prediction_scores = self.predictions(sequence_outputs) diff --git a/transformers/tests/modeling_albert_test.py b/transformers/tests/modeling_albert_test.py index 979e0488eb..466f473332 100644 --- a/transformers/tests/modeling_albert_test.py +++ b/transformers/tests/modeling_albert_test.py @@ -35,7 +35,6 @@ else: class AlbertModelTest(CommonTestCases.CommonModelTester): all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else () - test_head_masking = False class AlbertModelTester(object):