From 25a31953e8820eb5c88d8ad35ee547efccfe577c Mon Sep 17 00:00:00 2001 From: Lysandre Date: Wed, 30 Oct 2019 21:18:06 +0000 Subject: [PATCH] Output Attentions + output hidden states --- transformers/modeling_albert.py | 58 ++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 9 deletions(-) diff --git a/transformers/modeling_albert.py b/transformers/modeling_albert.py index b45208b696..52cab2ea69 100644 --- a/transformers/modeling_albert.py +++ b/transformers/modeling_albert.py @@ -105,6 +105,7 @@ class AlbertAttention(BertSelfAttention): def __init__(self, config): super(AlbertAttention, self).__init__(config) + self.output_attentions = config.output_attentions self.num_attention_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.attention_head_size = config.hidden_size // config.num_attention_heads @@ -177,7 +178,7 @@ class AlbertAttention(BertSelfAttention): projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b projected_context_layer_dropout = self.dropout(projected_context_layer) layernormed_context_layer = self.LayerNorm(input_ids + projected_context_layer_dropout) - return layernormed_context_layer + return (layernormed_context_layer, attention_probs) if self.output_attentions else (layernormed_context_layer,) class AlbertLayer(nn.Module): @@ -193,25 +194,45 @@ class AlbertLayer(nn.Module): def forward(self, hidden_states, attention_mask=None, head_mask=None): attention_output = self.attention(hidden_states, attention_mask) - ffn_output = self.ffn(attention_output) + ffn_output = self.ffn(attention_output[0]) ffn_output = self.activation(ffn_output) ffn_output = self.ffn_output(ffn_output) - hidden_states = self.full_layer_layer_norm(ffn_output + attention_output) + hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0]) - return hidden_states + return (hidden_states,) + attention_output[1:] # add attentions if we output them class AlbertLayerGroup(nn.Module): def __init__(self, config): super(AlbertLayerGroup, self).__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)]) def forward(self, hidden_states, attention_mask=None, head_mask=None): - for albert_layer in self.albert_layers: - hidden_states = albert_layer(hidden_states, attention_mask, head_mask) + layer_hidden_states = () + layer_attentions = () - return hidden_states + for albert_layer in self.albert_layers: + if self.output_hidden_states: + layer_hidden_states = layer_hidden_states + (hidden_states,) + + layer_output = albert_layer(hidden_states, attention_mask, head_mask) + hidden_states = layer_output[0] + + if self.output_attentions: + layer_attentions = layer_attentions + (layer_output[1],) + + if self.output_hidden_states: + layer_hidden_states = layer_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if self.output_hidden_states: + outputs = outputs + (layer_hidden_states,) + if self.output_attentions: + outputs = outputs + (layer_attentions,) + return outputs # last-layer hidden state, (layer hidden states), (layer attentions) class AlbertTransformer(nn.Module): @@ -227,11 +248,30 @@ class AlbertTransformer(nn.Module): def forward(self, hidden_states, attention_mask=None, head_mask=None): hidden_states = self.embedding_hidden_mapping_in(hidden_states) + all_attentions = () + + if self.output_hidden_states: + all_hidden_states = (hidden_states,) + for layer_idx in range(self.config.num_hidden_layers): group_idx = int(layer_idx / self.config.num_hidden_layers * self.config.num_hidden_groups) - hidden_states = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask) + layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask) - return (hidden_states,) + hidden_states = layer_group_output[0] + + if self.output_attentions: + all_attentions = all_attentions + layer_group_output[1] + + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + + outputs = (hidden_states,) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states,) + if self.output_attentions: + outputs = outputs + (all_attentions,) + return outputs # last-layer hidden state, (all hidden states), (all attentions) ALBERT_START_DOCSTRING = r""" The ALBERT model was proposed in