# Patch summary (free text ahead of the first "diff --git" header).
#
# Adds an optional `output_attentions=False` constructor argument to
# BertSelfAttention, BertLayer, BertEncoder and BertModel and stores it on each
# instance. With the default (False) every forward() keeps its previous return
# value. With the flag set, forward() returns an extra leading element:
#   BertSelfAttention -> (attention_probs, context_layer)
#   BertLayer         -> (attention_output, layer_output)
#   BertEncoder       -> (all_attentions, all_encoder_layers)
#   BertModel         -> (all_attentions, encoded_layers, pooled_output)
#
# NOTE(review): BertLayer still constructs `BertAttention(config)` without the
# flag, and this patch has no hunk for BertAttention. The first element that
# BertLayer returns is `attention_output` (the same value it feeds to
# self.intermediate), so that is what BertEncoder appends to `all_attentions`,
# not the `attention_probs` exposed by BertSelfAttention. Confirm whether
# BertAttention is meant to forward the flag and the probabilities.
#
# NOTE(review): the usage example kept as context in the BertModel docstring
# unpacks two return values; with output_attentions=True there are three.
#
# NOTE(review): line breaks, indentation and blank context lines below were
# reconstructed from the hunk headers' line counts. Blank context lines are a
# single space. Verify whitespace against the target file (or apply with
# --ignore-whitespace) before relying on this patch.
diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py
index b9b6837193..72d8ff5195 100644
--- a/pytorch_pretrained_bert/modeling.py
+++ b/pytorch_pretrained_bert/modeling.py
@@ -275,7 +275,7 @@ class BertEmbeddings(nn.Module):
 
 
 class BertSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertSelfAttention, self).__init__()
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
@@ -291,6 +291,8 @@ class BertSelfAttention(nn.Module):
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
 
+        self.output_attentions = output_attentions
+
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
@@ -322,7 +324,10 @@ class BertSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        return context_layer
+        if self.output_attentions:
+            return attention_probs, context_layer
+        else:
+            return context_layer
 
 
 class BertSelfOutput(nn.Module):
@@ -381,33 +386,43 @@ class BertOutput(nn.Module):
 
 
 class BertLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertLayer, self).__init__()
         self.attention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
+        self.output_attentions = output_attentions
 
     def forward(self, hidden_states, attention_mask):
         attention_output = self.attention(hidden_states, attention_mask)
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
+        if self.output_attentions:
+            return attention_output, layer_output
         return layer_output
 
 
 class BertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertEncoder, self).__init__()
-        layer = BertLayer(config)
+        layer = BertLayer(config, output_attentions=output_attentions)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        self.output_attentions = output_attentions
 
     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
+        all_attentions = []
         for layer_module in self.layer:
             hidden_states = layer_module(hidden_states, attention_mask)
+            if self.output_attentions:
+                attentions, hidden_states = hidden_states
+                all_attentions.append(attentions)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
+        if self.output_attentions:
+            return all_attentions, all_encoder_layers
         return all_encoder_layers
 
 
@@ -699,12 +714,13 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertModel, self).__init__(config)
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config)
+        self.encoder = BertEncoder(config, output_attentions=output_attentions)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)
+        self.output_attentions = output_attentions
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
         if attention_mask is None:
@@ -731,10 +747,14 @@ class BertModel(BertPreTrainedModel):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        if self.output_attentions:
+            all_attentions, encoded_layers = encoded_layers
         sequence_output = encoded_layers[-1]
         pooled_output = self.pooler(sequence_output)
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
+        if self.output_attentions:
+            return all_attentions, encoded_layers, pooled_output
         return encoded_layers, pooled_output
 
 