update QA models tests + run_generation
This commit is contained in:
@@ -191,17 +191,19 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
|
||||
cls_index=sequence_labels,
|
||||
is_impossible=is_impossible_labels)
|
||||
|
||||
total_loss, start_logits, end_logits, cls_logits = outputs
|
||||
(total_loss,) = outputs
|
||||
|
||||
outputs = model(input_ids, start_positions=sequence_labels,
|
||||
end_positions=sequence_labels)
|
||||
|
||||
total_loss, start_logits, end_logits = outputs
|
||||
(total_loss,) = outputs
|
||||
|
||||
result = {
|
||||
"loss": total_loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
"start_top_log_probs": start_top_log_probs,
|
||||
"start_top_index": start_top_index,
|
||||
"end_top_log_probs": end_top_log_probs,
|
||||
"end_top_index": end_top_index,
|
||||
"cls_logits": cls_logits,
|
||||
}
|
||||
|
||||
@@ -209,11 +211,17 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
list(result["start_top_log_probs"].size()),
|
||||
[self.batch_size, model.config.start_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
list(result["start_top_index"].size()),
|
||||
[self.batch_size, model.config.start_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_top_log_probs"].size()),
|
||||
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_top_index"].size()),
|
||||
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["cls_logits"].size()),
|
||||
[self.batch_size])
|
||||
|
||||
@@ -210,17 +210,19 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
cls_index=sequence_labels,
|
||||
is_impossible=is_impossible_labels)
|
||||
|
||||
total_loss, start_logits, end_logits, cls_logits, mems = outputs
|
||||
total_loss, mems = outputs
|
||||
|
||||
outputs = model(input_ids_1, start_positions=sequence_labels,
|
||||
end_positions=sequence_labels)
|
||||
|
||||
total_loss, start_logits, end_logits, mems = outputs
|
||||
total_loss, mems = outputs
|
||||
|
||||
result = {
|
||||
"loss": total_loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
"start_top_log_probs": start_top_log_probs,
|
||||
"start_top_index": start_top_index,
|
||||
"end_top_log_probs": end_top_log_probs,
|
||||
"end_top_index": end_top_index,
|
||||
"cls_logits": cls_logits,
|
||||
"mems": mems,
|
||||
}
|
||||
@@ -229,11 +231,17 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
list(result["start_top_log_probs"].size()),
|
||||
[self.batch_size, model.config.start_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
list(result["start_top_index"].size()),
|
||||
[self.batch_size, model.config.start_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_top_log_probs"].size()),
|
||||
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_top_index"].size()),
|
||||
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
|
||||
self.parent.assertListEqual(
|
||||
list(result["cls_logits"].size()),
|
||||
[self.batch_size])
|
||||
|
||||
Reference in New Issue
Block a user