[ProphetNet] Add Question Generation Model + Test (#7942)
* new prophetnet model * correct name * make style
This commit is contained in:
committed by
GitHub
parent
13842e413c
commit
29792864cb
@@ -1073,3 +1073,33 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
|
|||||||
[EXPECTED_SUMMARIZE_100],
|
[EXPECTED_SUMMARIZE_100],
|
||||||
generated_titles,
|
generated_titles,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_question_gen_inference(self):
|
||||||
|
model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased-squad-qg")
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased-squad-qg")
|
||||||
|
|
||||||
|
INPUTS = [
|
||||||
|
"Bill Gates [SEP] Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975.",
|
||||||
|
"1975 [SEP] Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975.",
|
||||||
|
"April 4, 1975 [SEP] Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975.",
|
||||||
|
]
|
||||||
|
|
||||||
|
input_ids = tokenizer(INPUTS, truncation=True, padding=True, return_tensors="pt").input_ids
|
||||||
|
input_ids = input_ids.to(torch_device)
|
||||||
|
|
||||||
|
gen_output = model.generate(input_ids, num_beams=5, early_stopping=True)
|
||||||
|
generated_questions = tokenizer.batch_decode(gen_output, skip_special_tokens=True)
|
||||||
|
|
||||||
|
EXPECTED_QUESTIONS = [
|
||||||
|
"along with paul allen, who founded microsoft?",
|
||||||
|
"what year was microsoft founded?",
|
||||||
|
"on what date was microsoft founded?",
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
EXPECTED_QUESTIONS,
|
||||||
|
generated_questions,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user