adding attention outputs in bert
This commit is contained in:
@@ -275,12 +275,13 @@ class BertEmbeddings(nn.Module):
 
 
 class BertSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertSelfAttention, self).__init__()
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
                 "The hidden size (%d) is not a multiple of the number of attention "
                 "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+        self.output_attentions = output_attentions
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
@@ -322,6 +323,8 @@ class BertSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
+        if self.output_attentions:
+            return attention_probs, context_layer
         return context_layer
 
 
@@ -340,14 +343,19 @@ class BertSelfOutput(nn.Module):
 
 
 class BertAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertAttention, self).__init__()
-        self.self = BertSelfAttention(config)
+        self.output_attentions = output_attentions
+        self.self = BertSelfAttention(config, output_attentions=output_attentions)
         self.output = BertSelfOutput(config)
 
     def forward(self, input_tensor, attention_mask):
         self_output = self.self(input_tensor, attention_mask)
+        if self.output_attentions:
+            attentions, self_output = self_output
         attention_output = self.output(self_output, input_tensor)
+        if self.output_attentions:
+            return attentions, attention_output
         return attention_output
 
 
@@ -381,33 +389,45 @@ class BertOutput(nn.Module):
 
 
 class BertLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertLayer, self).__init__()
-        self.attention = BertAttention(config)
+        self.output_attentions = output_attentions
+        self.attention = BertAttention(config, output_attentions=output_attentions)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
 
     def forward(self, hidden_states, attention_mask):
         attention_output = self.attention(hidden_states, attention_mask)
+        if self.output_attentions:
+            attentions, attention_output = attention_output
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
+        if self.output_attentions:
+            return attentions, layer_output
         return layer_output
 
 
 class BertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertEncoder, self).__init__()
-        layer = BertLayer(config)
+        self.output_attentions = output_attentions
+        layer = BertLayer(config, output_attentions=output_attentions)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
 
     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
+        all_attentions = []
         for layer_module in self.layer:
             hidden_states = layer_module(hidden_states, attention_mask)
+            if self.output_attentions:
+                attentions, hidden_states = hidden_states
+                all_attentions.append(attentions)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
+        if self.output_attentions:
+            return all_attentions, all_encoder_layers
         return all_encoder_layers
 
 
@@ -699,10 +719,11 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertModel, self).__init__(config)
+        self.output_attentions = output_attentions
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config)
+        self.encoder = BertEncoder(config, output_attentions=output_attentions)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)
 
@@ -731,10 +752,14 @@ class BertModel(BertPreTrainedModel):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        if self.output_attentions:
+            all_attentions, encoded_layers = encoded_layers
         sequence_output = encoded_layers[-1]
         pooled_output = self.pooler(sequence_output)
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
+        if self.output_attentions:
+            return all_attentions, encoded_layers, pooled_output
         return encoded_layers, pooled_output
 
 
|
|||||||
@@ -133,11 +133,28 @@ class GPT2ModelTest(unittest.TestCase):
             }
             return outputs
 
+        def create_gpt2_lm_head_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
+                                                      mc_labels, lm_labels, mc_token_ids):
+            model = GPT2LMHeadModel(config, output_attentions=True)
+            model.eval()
+            loss = model(input_ids, position_ids, token_type_ids, lm_labels)
+            attentions, lm_logits, presents = model(input_ids, position_ids, token_type_ids)
+            outputs = {
+                "loss": loss,
+                "lm_logits": lm_logits,
+                "presents": presents,
+                "attentions": attentions,
+            }
+            return outputs
+
         def check_gpt2_lm_head_output(self, result):
             total_voc = self.n_special + self.vocab_size
             self.parent.assertListEqual(
                 list(result["lm_logits"].size()),
                 [self.batch_size, self.n_choices, self.seq_length, total_voc])
+            self.parent.assertListEqual(
+                list(result["presents"].size()),
+                [self.batch_size, self.n_choices, self.seq_length, total_voc])
 
         def check_gpt2_lm_head_loss_output(self, result):
             self.parent.assertListEqual(
@@ -160,6 +177,23 @@ class GPT2ModelTest(unittest.TestCase):
             }
             return outputs
 
+        def create_gpt2_double_heads_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
+                                                           mc_labels, lm_labels, mc_token_ids):
+            model = GPT2DoubleHeadsModel(config, output_attentions=True)
+            model.eval()
+            loss = model(input_ids, mc_token_ids,
+                         lm_labels=lm_labels, mc_labels=mc_labels,
+                         token_type_ids=token_type_ids, position_ids=position_ids)
+            attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
+            outputs = {
+                "loss": loss,
+                "lm_logits": lm_logits,
+                "mc_logits": mc_logits,
+                "presents": presents,
+                "attentions": attentions,
+            }
+            return outputs
+
         def check_gpt2_double_heads_output(self, result):
             total_voc = self.n_special + self.vocab_size
             self.parent.assertListEqual(
|
|||||||
Reference in New Issue
Block a user