add output_attentions for BertModel
This commit is contained in:
@@ -275,7 +275,7 @@ class BertEmbeddings(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertSelfAttention(nn.Module):
|
class BertSelfAttention(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(BertSelfAttention, self).__init__()
|
super(BertSelfAttention, self).__init__()
|
||||||
if config.hidden_size % config.num_attention_heads != 0:
|
if config.hidden_size % config.num_attention_heads != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -291,6 +291,8 @@ class BertSelfAttention(nn.Module):
|
|||||||
|
|
||||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||||
|
|
||||||
|
self.output_attentions = output_attentions
|
||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(*new_x_shape)
|
||||||
@@ -322,7 +324,10 @@ class BertSelfAttention(nn.Module):
|
|||||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
return context_layer
|
if self.output_attentions:
|
||||||
|
return attention_probs, context_layer
|
||||||
|
else:
|
||||||
|
return context_layer
|
||||||
|
|
||||||
|
|
||||||
class BertSelfOutput(nn.Module):
|
class BertSelfOutput(nn.Module):
|
||||||
@@ -381,33 +386,43 @@ class BertOutput(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class BertLayer(nn.Module):
|
class BertLayer(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(BertLayer, self).__init__()
|
super(BertLayer, self).__init__()
|
||||||
self.attention = BertAttention(config)
|
self.attention = BertAttention(config)
|
||||||
self.intermediate = BertIntermediate(config)
|
self.intermediate = BertIntermediate(config)
|
||||||
self.output = BertOutput(config)
|
self.output = BertOutput(config)
|
||||||
|
self.output_attentions = output_attentions
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask):
|
def forward(self, hidden_states, attention_mask):
|
||||||
attention_output = self.attention(hidden_states, attention_mask)
|
attention_output = self.attention(hidden_states, attention_mask)
|
||||||
intermediate_output = self.intermediate(attention_output)
|
intermediate_output = self.intermediate(attention_output)
|
||||||
layer_output = self.output(intermediate_output, attention_output)
|
layer_output = self.output(intermediate_output, attention_output)
|
||||||
|
if self.output_attentions:
|
||||||
|
return attention_output, layer_output
|
||||||
return layer_output
|
return layer_output
|
||||||
|
|
||||||
|
|
||||||
class BertEncoder(nn.Module):
|
class BertEncoder(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(BertEncoder, self).__init__()
|
super(BertEncoder, self).__init__()
|
||||||
layer = BertLayer(config)
|
layer = BertLayer(config, output_attentions=output_attentions)
|
||||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
||||||
|
self.output_attentions = output_attentions
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
|
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
|
||||||
all_encoder_layers = []
|
all_encoder_layers = []
|
||||||
|
all_attentions = []
|
||||||
for layer_module in self.layer:
|
for layer_module in self.layer:
|
||||||
hidden_states = layer_module(hidden_states, attention_mask)
|
hidden_states = layer_module(hidden_states, attention_mask)
|
||||||
|
if self.output_attentions:
|
||||||
|
attentions, hidden_states = hidden_states
|
||||||
|
all_attentions.append(attentions)
|
||||||
if output_all_encoded_layers:
|
if output_all_encoded_layers:
|
||||||
all_encoder_layers.append(hidden_states)
|
all_encoder_layers.append(hidden_states)
|
||||||
if not output_all_encoded_layers:
|
if not output_all_encoded_layers:
|
||||||
all_encoder_layers.append(hidden_states)
|
all_encoder_layers.append(hidden_states)
|
||||||
|
if self.output_attentions:
|
||||||
|
return all_attentions, all_encoder_layers
|
||||||
return all_encoder_layers
|
return all_encoder_layers
|
||||||
|
|
||||||
|
|
||||||
@@ -699,12 +714,13 @@ class BertModel(BertPreTrainedModel):
|
|||||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(BertModel, self).__init__(config)
|
super(BertModel, self).__init__(config)
|
||||||
self.embeddings = BertEmbeddings(config)
|
self.embeddings = BertEmbeddings(config)
|
||||||
self.encoder = BertEncoder(config)
|
self.encoder = BertEncoder(config, output_attentions=output_attentions)
|
||||||
self.pooler = BertPooler(config)
|
self.pooler = BertPooler(config)
|
||||||
self.apply(self.init_bert_weights)
|
self.apply(self.init_bert_weights)
|
||||||
|
self.output_attentions = output_attentions
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
@@ -731,10 +747,14 @@ class BertModel(BertPreTrainedModel):
|
|||||||
encoded_layers = self.encoder(embedding_output,
|
encoded_layers = self.encoder(embedding_output,
|
||||||
extended_attention_mask,
|
extended_attention_mask,
|
||||||
output_all_encoded_layers=output_all_encoded_layers)
|
output_all_encoded_layers=output_all_encoded_layers)
|
||||||
|
if self.output_attentions:
|
||||||
|
all_attentions, encoded_layers = encoded_layers
|
||||||
sequence_output = encoded_layers[-1]
|
sequence_output = encoded_layers[-1]
|
||||||
pooled_output = self.pooler(sequence_output)
|
pooled_output = self.pooler(sequence_output)
|
||||||
if not output_all_encoded_layers:
|
if not output_all_encoded_layers:
|
||||||
encoded_layers = encoded_layers[-1]
|
encoded_layers = encoded_layers[-1]
|
||||||
|
if self.output_attentions:
|
||||||
|
return all_attentions, encoded_layers, pooled_output
|
||||||
return encoded_layers, pooled_output
|
return encoded_layers, pooled_output
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user