From 29792864cb8cc9f0d8da4249166b407c6b91ff82 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 21 Oct 2020 11:49:58 +0200 Subject: [PATCH] [ProphetNet] Add Question Generation Model + Test (#7942) * new prophetnet model * correct name * make style --- tests/test_modeling_prophetnet.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py index bac4776271..90ca042db8 100644 --- a/tests/test_modeling_prophetnet.py +++ b/tests/test_modeling_prophetnet.py @@ -1073,3 +1073,33 @@ class ProphetNetModelIntegrationTest(unittest.TestCase): [EXPECTED_SUMMARIZE_100], generated_titles, ) + + @slow + def test_question_gen_inference(self): + model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased-squad-qg") + model.to(torch_device) + + tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased-squad-qg") + + INPUTS = [ + "Bill Gates [SEP] Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975.", + "1975 [SEP] Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975.", + "April 4, 1975 [SEP] Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975.", + ] + + input_ids = tokenizer(INPUTS, truncation=True, padding=True, return_tensors="pt").input_ids + input_ids = input_ids.to(torch_device) + + gen_output = model.generate(input_ids, num_beams=5, early_stopping=True) + generated_questions = tokenizer.batch_decode(gen_output, skip_special_tokens=True) + + EXPECTED_QUESTIONS = [ + "along with paul allen, who founded microsoft?", + "what year was microsoft founded?", + "on what date was microsoft founded?", + ] + + self.assertListEqual( + EXPECTED_QUESTIONS, + generated_questions, + )