adding attention outputs in bert
This commit is contained in:
@@ -133,11 +133,28 @@ class GPT2ModelTest(unittest.TestCase):
|
||||
}
|
||||
return outputs
|
||||
|
||||
def create_gpt2_lm_head_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
|
||||
mc_labels, lm_labels, mc_token_ids):
|
||||
model = GPT2LMHeadModel(config, output_attentions=True)
|
||||
model.eval()
|
||||
loss = model(input_ids, position_ids, token_type_ids, lm_labels)
|
||||
attentions, lm_logits, presents = model(input_ids, position_ids, token_type_ids)
|
||||
outputs = {
|
||||
"loss": loss,
|
||||
"lm_logits": lm_logits,
|
||||
"presents": presents,
|
||||
"attentions": attentions,
|
||||
}
|
||||
return outputs
|
||||
|
||||
def check_gpt2_lm_head_output(self, result):
|
||||
total_voc = self.n_special + self.vocab_size
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||
self.parent.assertListEqual(
|
||||
list(result["presents"].size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||
|
||||
def check_gpt2_lm_head_loss_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
@@ -160,6 +177,23 @@ class GPT2ModelTest(unittest.TestCase):
|
||||
}
|
||||
return outputs
|
||||
|
||||
def create_gpt2_double_heads_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
|
||||
mc_labels, lm_labels, mc_token_ids):
|
||||
model = GPT2DoubleHeadsModel(config, output_attentions=True)
|
||||
model.eval()
|
||||
loss = model(input_ids, mc_token_ids,
|
||||
lm_labels=lm_labels, mc_labels=mc_labels,
|
||||
token_type_ids=token_type_ids, position_ids=position_ids)
|
||||
attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||
outputs = {
|
||||
"loss": loss,
|
||||
"lm_logits": lm_logits,
|
||||
"mc_logits": mc_logits,
|
||||
"presents": presents,
|
||||
"attentions": attentions,
|
||||
}
|
||||
return outputs
|
||||
|
||||
def check_gpt2_double_heads_output(self, result):
|
||||
total_voc = self.n_special + self.vocab_size
|
||||
self.parent.assertListEqual(
|
||||
|
||||
Reference in New Issue
Block a user