update QA models tests + run_generation
This commit is contained in:
@@ -131,8 +131,10 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_name', type=str, default=None, required=True,
|
||||
help="GPT, GPT-2, Transformer-XL or XLNet pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--model_type", default=None, type=str, required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
|
||||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
|
||||
parser.add_argument("--prompt", type=str, default="")
|
||||
parser.add_argument("--padding_text", type=str, default="")
|
||||
parser.add_argument("--length", type=int, default=20)
|
||||
@@ -150,15 +152,10 @@ def main():
|
||||
|
||||
set_seed(args)
|
||||
|
||||
args.model_type = ""
|
||||
for key in MODEL_CLASSES:
|
||||
if key in args.model_name.lower():
|
||||
args.model_type = key # take the first match in model types
|
||||
break
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name)
|
||||
model = model_class.from_pretrained(args.model_name)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||
model = model_class.from_pretrained(args.model_name_or_path)
|
||||
model.to(args.device)
|
||||
model.eval()
|
||||
|
||||
|
||||
@@ -101,7 +101,8 @@ class ExamplesTests(unittest.TestCase):
|
||||
"--prompt=Hello",
|
||||
"--length=10",
|
||||
"--seed=42"]
|
||||
model_name = "--model_name=openai-gpt"
|
||||
model_type, model_name = ("--model_type=openai-gpt",
|
||||
"--model_name_or_path=openai-gpt")
|
||||
with patch.object(sys, 'argv', testargs + [model_name]):
|
||||
result = run_generation.main()
|
||||
self.assertGreaterEqual(len(result), 10)
|
||||
|
||||
Reference in New Issue
Block a user