[ProphetNet] Add Question Generation Model + Test (#7942)
* new prophetnet model * correct name * make style
This commit is contained in:
committed by
GitHub
parent
13842e413c
commit
29792864cb
@@ -1073,3 +1073,33 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
|
||||
[EXPECTED_SUMMARIZE_100],
|
||||
generated_titles,
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_question_gen_inference(self):
|
||||
model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased-squad-qg")
|
||||
model.to(torch_device)
|
||||
|
||||
tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased-squad-qg")
|
||||
|
||||
INPUTS = [
|
||||
"Bill Gates [SEP] Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975.",
|
||||
"1975 [SEP] Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975.",
|
||||
"April 4, 1975 [SEP] Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975.",
|
||||
]
|
||||
|
||||
input_ids = tokenizer(INPUTS, truncation=True, padding=True, return_tensors="pt").input_ids
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
gen_output = model.generate(input_ids, num_beams=5, early_stopping=True)
|
||||
generated_questions = tokenizer.batch_decode(gen_output, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_QUESTIONS = [
|
||||
"along with paul allen, who founded microsoft?",
|
||||
"what year was microsoft founded?",
|
||||
"on what date was microsoft founded?",
|
||||
]
|
||||
|
||||
self.assertListEqual(
|
||||
EXPECTED_QUESTIONS,
|
||||
generated_questions,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user