update QA models tests + run_generation

This commit is contained in:
thomwolf
2019-07-15 17:45:24 +02:00
parent 15d8b1266c
commit e691fc0963
4 changed files with 41 additions and 27 deletions

View File

@@ -191,17 +191,19 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
cls_index=sequence_labels,
is_impossible=is_impossible_labels)
total_loss, start_logits, end_logits, cls_logits = outputs
(total_loss,) = outputs
outputs = model(input_ids, start_positions=sequence_labels,
end_positions=sequence_labels)
total_loss, start_logits, end_logits = outputs
(total_loss,) = outputs
result = {
"loss": total_loss,
"start_logits": start_logits,
"end_logits": end_logits,
"start_top_log_probs": start_top_log_probs,
"start_top_index": start_top_index,
"end_top_log_probs": end_top_log_probs,
"end_top_index": end_top_index,
"cls_logits": cls_logits,
}
@@ -209,11 +211,17 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
list(result["start_top_log_probs"].size()),
[self.batch_size, model.config.start_n_top])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
list(result["start_top_index"].size()),
[self.batch_size, model.config.start_n_top])
self.parent.assertListEqual(
list(result["end_top_log_probs"].size()),
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
self.parent.assertListEqual(
list(result["end_top_index"].size()),
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
self.parent.assertListEqual(
list(result["cls_logits"].size()),
[self.batch_size])

View File

@@ -210,17 +210,19 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
cls_index=sequence_labels,
is_impossible=is_impossible_labels)
total_loss, start_logits, end_logits, cls_logits, mems = outputs
total_loss, mems = outputs
outputs = model(input_ids_1, start_positions=sequence_labels,
end_positions=sequence_labels)
total_loss, start_logits, end_logits, mems = outputs
total_loss, mems = outputs
result = {
"loss": total_loss,
"start_logits": start_logits,
"end_logits": end_logits,
"start_top_log_probs": start_top_log_probs,
"start_top_index": start_top_index,
"end_top_log_probs": end_top_log_probs,
"end_top_index": end_top_index,
"cls_logits": cls_logits,
"mems": mems,
}
@@ -229,11 +231,17 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
list(result["loss"].size()),
[])
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
list(result["start_top_log_probs"].size()),
[self.batch_size, model.config.start_n_top])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
list(result["start_top_index"].size()),
[self.batch_size, model.config.start_n_top])
self.parent.assertListEqual(
list(result["end_top_log_probs"].size()),
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
self.parent.assertListEqual(
list(result["end_top_index"].size()),
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
self.parent.assertListEqual(
list(result["cls_logits"].size()),
[self.batch_size])