From a3274ac40b14025ee857897ecfaff4fb07bcb61d Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 3 Jun 2019 16:11:45 -0500 Subject: [PATCH] adding attention outputs in bert --- pytorch_pretrained_bert/modeling.py | 43 +++++++++++++++++++++++------ tests/modeling_gpt2_test.py | 34 +++++++++++++++++++++++ 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index b9b6837193..27682eb369 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -275,12 +275,13 @@ class BertEmbeddings(nn.Module): class BertSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(BertSelfAttention, self).__init__() if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.output_attentions = output_attentions self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size @@ -322,6 +323,8 @@ class BertSelfAttention(nn.Module): context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) + if self.output_attentions: + return attention_probs, context_layer return context_layer @@ -340,14 +343,19 @@ class BertSelfOutput(nn.Module): class BertAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(BertAttention, self).__init__() - self.self = BertSelfAttention(config) + self.output_attentions = output_attentions + self.self = BertSelfAttention(config, output_attentions=output_attentions) self.output = BertSelfOutput(config) def forward(self, input_tensor, attention_mask): self_output = self.self(input_tensor, attention_mask) + if self.output_attentions: + attentions, self_output = self_output attention_output = self.output(self_output, input_tensor) + if self.output_attentions: + return attentions, attention_output return attention_output @@ -381,33 +389,45 @@ class BertOutput(nn.Module): class BertLayer(nn.Module): - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(BertLayer, self).__init__() - self.attention = BertAttention(config) + self.output_attentions = output_attentions + self.attention = BertAttention(config, output_attentions=output_attentions) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) def forward(self, hidden_states, attention_mask): attention_output = self.attention(hidden_states, attention_mask) + if self.output_attentions: + attentions, attention_output = attention_output intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) + if self.output_attentions: + return attentions, layer_output return layer_output class BertEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(BertEncoder, self).__init__() - layer = BertLayer(config) + self.output_attentions = output_attentions + layer = BertLayer(config, output_attentions=output_attentions) self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): all_encoder_layers = [] + all_attentions = [] for layer_module in self.layer: hidden_states = layer_module(hidden_states, attention_mask) + if self.output_attentions: + attentions, hidden_states = hidden_states + all_attentions.append(attentions) if output_all_encoded_layers: all_encoder_layers.append(hidden_states) if not output_all_encoded_layers: all_encoder_layers.append(hidden_states) + if self.output_attentions: + return all_attentions, all_encoder_layers return all_encoder_layers @@ -699,10 +719,11 @@ class BertModel(BertPreTrainedModel): all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(BertModel, self).__init__(config) + self.output_attentions = output_attentions self.embeddings = BertEmbeddings(config) - self.encoder = BertEncoder(config) + self.encoder = BertEncoder(config, output_attentions=output_attentions) self.pooler = BertPooler(config) self.apply(self.init_bert_weights) @@ -731,10 +752,14 @@ class BertModel(BertPreTrainedModel): encoded_layers = self.encoder(embedding_output, extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers) + if self.output_attentions: + all_attentions, encoded_layers = encoded_layers sequence_output = encoded_layers[-1] pooled_output = self.pooler(sequence_output) if not output_all_encoded_layers: encoded_layers = encoded_layers[-1] + if self.output_attentions: + return all_attentions, encoded_layers, pooled_output return encoded_layers, pooled_output diff --git a/tests/modeling_gpt2_test.py b/tests/modeling_gpt2_test.py index 6804b794c5..41cc9b8fd3 100644 --- a/tests/modeling_gpt2_test.py +++ b/tests/modeling_gpt2_test.py @@ -133,11 +133,28 @@ class GPT2ModelTest(unittest.TestCase): } return outputs + def create_gpt2_lm_head_with_output_attention(self, config, input_ids, token_type_ids, position_ids, + mc_labels, lm_labels, mc_token_ids): + model = GPT2LMHeadModel(config, output_attentions=True) + model.eval() + loss = model(input_ids, position_ids, token_type_ids, lm_labels) + attentions, lm_logits, presents = model(input_ids, position_ids, token_type_ids) + outputs = { + "loss": loss, + "lm_logits": lm_logits, + "presents": presents, + "attentions": attentions, + } + return outputs + def check_gpt2_lm_head_output(self, result): total_voc = self.n_special + self.vocab_size self.parent.assertListEqual( list(result["lm_logits"].size()), [self.batch_size, self.n_choices, self.seq_length, total_voc]) + self.parent.assertListEqual( + list(result["presents"].size()), + [self.batch_size, self.n_choices, self.seq_length, total_voc]) def check_gpt2_lm_head_loss_output(self, result): self.parent.assertListEqual( @@ -160,6 +177,23 @@ class GPT2ModelTest(unittest.TestCase): } return outputs + def create_gpt2_double_heads_with_output_attention(self, config, input_ids, token_type_ids, position_ids, + mc_labels, lm_labels, mc_token_ids): + model = GPT2DoubleHeadsModel(config, output_attentions=True) + model.eval() + loss = model(input_ids, mc_token_ids, + lm_labels=lm_labels, mc_labels=mc_labels, + token_type_ids=token_type_ids, position_ids=position_ids) + attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids) + outputs = { + "loss": loss, + "lm_logits": lm_logits, + "mc_logits": mc_logits, + "presents": presents, + "attentions": attentions, + } + return outputs + def check_gpt2_double_heads_output(self, result): total_voc = self.n_special + self.vocab_size self.parent.assertListEqual(