adding attention outputs in bert

This commit is contained in:
thomwolf
2019-06-03 16:11:45 -05:00
parent 826496580b
commit a3274ac40b
2 changed files with 68 additions and 9 deletions

View File

@@ -275,12 +275,13 @@ class BertEmbeddings(nn.Module):
class BertSelfAttention(nn.Module):
def __init__(self, config):
def __init__(self, config, output_attentions=False):
super(BertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.output_attentions = output_attentions
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
@@ -322,6 +323,8 @@ class BertSelfAttention(nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
if self.output_attentions:
return attention_probs, context_layer
return context_layer
@@ -340,14 +343,19 @@ class BertSelfOutput(nn.Module):
class BertAttention(nn.Module):
def __init__(self, config):
def __init__(self, config, output_attentions=False):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(config)
self.output_attentions = output_attentions
self.self = BertSelfAttention(config, output_attentions=output_attentions)
self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
if self.output_attentions:
attentions, self_output = self_output
attention_output = self.output(self_output, input_tensor)
if self.output_attentions:
return attentions, attention_output
return attention_output
@@ -381,33 +389,45 @@ class BertOutput(nn.Module):
class BertLayer(nn.Module):
def __init__(self, config):
def __init__(self, config, output_attentions=False):
super(BertLayer, self).__init__()
self.attention = BertAttention(config)
self.output_attentions = output_attentions
self.attention = BertAttention(config, output_attentions=output_attentions)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
if self.output_attentions:
attentions, attention_output = attention_output
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
if self.output_attentions:
return attentions, layer_output
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
def __init__(self, config, output_attentions=False):
super(BertEncoder, self).__init__()
layer = BertLayer(config)
self.output_attentions = output_attentions
layer = BertLayer(config, output_attentions=output_attentions)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
all_attentions = []
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask)
if self.output_attentions:
attentions, hidden_states = hidden_states
all_attentions.append(attentions)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if self.output_attentions:
return all_attentions, all_encoder_layers
return all_encoder_layers
@@ -699,10 +719,11 @@ class BertModel(BertPreTrainedModel):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
def __init__(self, config, output_attentions=False):
super(BertModel, self).__init__(config)
self.output_attentions = output_attentions
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.encoder = BertEncoder(config, output_attentions=output_attentions)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
@@ -731,10 +752,14 @@ class BertModel(BertPreTrainedModel):
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers)
if self.output_attentions:
all_attentions, encoded_layers = encoded_layers
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
if self.output_attentions:
return all_attentions, encoded_layers, pooled_output
return encoded_layers, pooled_output