update QA models tests + run_generation
This commit is contained in:
@@ -131,8 +131,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model_name', type=str, default=None, required=True,
|
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||||
help="GPT, GPT-2, Transformer-XL or XLNet pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||||
|
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||||
|
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||||
parser.add_argument("--prompt", type=str, default="")
|
parser.add_argument("--prompt", type=str, default="")
|
||||||
parser.add_argument("--padding_text", type=str, default="")
|
parser.add_argument("--padding_text", type=str, default="")
|
||||||
parser.add_argument("--length", type=int, default=20)
|
parser.add_argument("--length", type=int, default=20)
|
||||||
@@ -150,15 +152,10 @@ def main():
|
|||||||
|
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
|
|
||||||
args.model_type = ""
|
args.model_type = args.model_type.lower()
|
||||||
for key in MODEL_CLASSES:
|
|
||||||
if key in args.model_name.lower():
|
|
||||||
args.model_type = key # take the first match in model types
|
|
||||||
break
|
|
||||||
|
|
||||||
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.model_name)
|
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||||
model = model_class.from_pretrained(args.model_name)
|
model = model_class.from_pretrained(args.model_name_or_path)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
|||||||
@@ -101,7 +101,8 @@ class ExamplesTests(unittest.TestCase):
|
|||||||
"--prompt=Hello",
|
"--prompt=Hello",
|
||||||
"--length=10",
|
"--length=10",
|
||||||
"--seed=42"]
|
"--seed=42"]
|
||||||
model_name = "--model_name=openai-gpt"
|
model_type, model_name = ("--model_type=openai-gpt",
|
||||||
|
"--model_name_or_path=openai-gpt")
|
||||||
with patch.object(sys, 'argv', testargs + [model_name]):
|
with patch.object(sys, 'argv', testargs + [model_name]):
|
||||||
result = run_generation.main()
|
result = run_generation.main()
|
||||||
self.assertGreaterEqual(len(result), 10)
|
self.assertGreaterEqual(len(result), 10)
|
||||||
|
|||||||
@@ -191,17 +191,19 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
|
|||||||
cls_index=sequence_labels,
|
cls_index=sequence_labels,
|
||||||
is_impossible=is_impossible_labels)
|
is_impossible=is_impossible_labels)
|
||||||
|
|
||||||
total_loss, start_logits, end_logits, cls_logits = outputs
|
(total_loss,) = outputs
|
||||||
|
|
||||||
outputs = model(input_ids, start_positions=sequence_labels,
|
outputs = model(input_ids, start_positions=sequence_labels,
|
||||||
end_positions=sequence_labels)
|
end_positions=sequence_labels)
|
||||||
|
|
||||||
total_loss, start_logits, end_logits = outputs
|
(total_loss,) = outputs
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"loss": total_loss,
|
"loss": total_loss,
|
||||||
"start_logits": start_logits,
|
"start_top_log_probs": start_top_log_probs,
|
||||||
"end_logits": end_logits,
|
"start_top_index": start_top_index,
|
||||||
|
"end_top_log_probs": end_top_log_probs,
|
||||||
|
"end_top_index": end_top_index,
|
||||||
"cls_logits": cls_logits,
|
"cls_logits": cls_logits,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,11 +211,17 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
|
|||||||
list(result["loss"].size()),
|
list(result["loss"].size()),
|
||||||
[])
|
[])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["start_logits"].size()),
|
list(result["start_top_log_probs"].size()),
|
||||||
[self.batch_size, self.seq_length])
|
[self.batch_size, model.config.start_n_top])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["end_logits"].size()),
|
list(result["start_top_index"].size()),
|
||||||
[self.batch_size, self.seq_length])
|
[self.batch_size, model.config.start_n_top])
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["end_top_log_probs"].size()),
|
||||||
|
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["end_top_index"].size()),
|
||||||
|
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["cls_logits"].size()),
|
list(result["cls_logits"].size()),
|
||||||
[self.batch_size])
|
[self.batch_size])
|
||||||
|
|||||||
@@ -210,17 +210,19 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
|||||||
cls_index=sequence_labels,
|
cls_index=sequence_labels,
|
||||||
is_impossible=is_impossible_labels)
|
is_impossible=is_impossible_labels)
|
||||||
|
|
||||||
total_loss, start_logits, end_logits, cls_logits, mems = outputs
|
total_loss, mems = outputs
|
||||||
|
|
||||||
outputs = model(input_ids_1, start_positions=sequence_labels,
|
outputs = model(input_ids_1, start_positions=sequence_labels,
|
||||||
end_positions=sequence_labels)
|
end_positions=sequence_labels)
|
||||||
|
|
||||||
total_loss, start_logits, end_logits, mems = outputs
|
total_loss, mems = outputs
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"loss": total_loss,
|
"loss": total_loss,
|
||||||
"start_logits": start_logits,
|
"start_top_log_probs": start_top_log_probs,
|
||||||
"end_logits": end_logits,
|
"start_top_index": start_top_index,
|
||||||
|
"end_top_log_probs": end_top_log_probs,
|
||||||
|
"end_top_index": end_top_index,
|
||||||
"cls_logits": cls_logits,
|
"cls_logits": cls_logits,
|
||||||
"mems": mems,
|
"mems": mems,
|
||||||
}
|
}
|
||||||
@@ -229,11 +231,17 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
|||||||
list(result["loss"].size()),
|
list(result["loss"].size()),
|
||||||
[])
|
[])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["start_logits"].size()),
|
list(result["start_top_log_probs"].size()),
|
||||||
[self.batch_size, self.seq_length])
|
[self.batch_size, model.config.start_n_top])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["end_logits"].size()),
|
list(result["start_top_index"].size()),
|
||||||
[self.batch_size, self.seq_length])
|
[self.batch_size, model.config.start_n_top])
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["end_top_log_probs"].size()),
|
||||||
|
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["end_top_index"].size()),
|
||||||
|
[self.batch_size, model.config.start_n_top * model.config.end_n_top])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["cls_logits"].size()),
|
list(result["cls_logits"].size()),
|
||||||
[self.batch_size])
|
[self.batch_size])
|
||||||
|
|||||||
Reference in New Issue
Block a user