From d9daad98c744e115bfb425f316a5e7d4f405a9a5 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 7 Nov 2019 19:55:43 +0000 Subject: [PATCH] Re-ordering of group_idx/layer_idx + Python 2 tests --- transformers/modeling_albert.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/transformers/modeling_albert.py b/transformers/modeling_albert.py index 6682930d89..640af537d0 100644 --- a/transformers/modeling_albert.py +++ b/transformers/modeling_albert.py @@ -281,11 +281,17 @@ class AlbertTransformer(nn.Module): if self.output_hidden_states: all_hidden_states = (hidden_states,) - for layer_idx in range(self.config.num_hidden_layers): - group_idx = int(layer_idx / self.config.num_hidden_layers * self.config.num_hidden_groups) + for i in range(self.config.num_hidden_layers): + # Number of layers in a hidden group layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups) + + # Index of the hidden group + group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups)) + + # Index of the layer inside the group + layer_idx = int(i - group_idx * layers_per_group) + layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group]) - hidden_states = layer_group_output[0] if self.output_attentions: