update QA models tests + run_generation

This commit is contained in:
thomwolf
2019-07-15 17:45:24 +02:00
parent 15d8b1266c
commit e691fc0963
4 changed files with 41 additions and 27 deletions

View File

@@ -131,8 +131,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default=None, required=True,
help="GPT, GPT-2, Transformer-XL or XLNet pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--padding_text", type=str, default="")
parser.add_argument("--length", type=int, default=20)
@@ -150,15 +152,10 @@ def main():
set_seed(args)
args.model_type = ""
for key in MODEL_CLASSES:
if key in args.model_name.lower():
args.model_type = key # take the first match in model types
break
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name)
model = model_class.from_pretrained(args.model_name)
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
model.to(args.device)
model.eval()

View File

@@ -101,7 +101,8 @@ class ExamplesTests(unittest.TestCase):
"--prompt=Hello",
"--length=10",
"--seed=42"]
model_name = "--model_name=openai-gpt"
model_type, model_name = ("--model_type=openai-gpt",
"--model_name_or_path=openai-gpt")
with patch.object(sys, 'argv', testargs + [model_name]):
result = run_generation.main()
self.assertGreaterEqual(len(result), 10)