Generate: fix plbart generation tests (#20391)
This commit is contained in:
@@ -409,12 +409,12 @@ class PLBartBaseIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
|||||||
src_text = ["Is 0 the first Fibonacci number ?", "Find the sum of all prime numbers ."]
|
src_text = ["Is 0 the first Fibonacci number ?", "Find the sum of all prime numbers ."]
|
||||||
tgt_text = ["0 the first Fibonacci number?", "the sum of all prime numbers.......... the the"]
|
tgt_text = ["0 the first Fibonacci number?", "the sum of all prime numbers.......... the the"]
|
||||||
|
|
||||||
@unittest.skip("This test is broken, fix me gante")
|
|
||||||
def test_base_generate(self):
|
def test_base_generate(self):
|
||||||
inputs = self.tokenizer([self.src_text[0]], return_tensors="pt").to(torch_device)
|
inputs = self.tokenizer([self.src_text[0]], return_tensors="pt").to(torch_device)
|
||||||
|
src_lan = self.tokenizer._convert_lang_code_special_format("en_XX")
|
||||||
translated_tokens = self.model.generate(
|
translated_tokens = self.model.generate(
|
||||||
input_ids=inputs["input_ids"].to(torch_device),
|
input_ids=inputs["input_ids"].to(torch_device),
|
||||||
decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"],
|
decoder_start_token_id=self.tokenizer.lang_code_to_id[src_lan],
|
||||||
)
|
)
|
||||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||||
self.assertEqual(self.tgt_text[0], decoded[0])
|
self.assertEqual(self.tgt_text[0], decoded[0])
|
||||||
@@ -422,8 +422,9 @@ class PLBartBaseIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
|||||||
@slow
|
@slow
|
||||||
def test_fill_mask(self):
|
def test_fill_mask(self):
|
||||||
inputs = self.tokenizer(["Is 0 the <mask> Fibonacci <mask> ?"], return_tensors="pt").to(torch_device)
|
inputs = self.tokenizer(["Is 0 the <mask> Fibonacci <mask> ?"], return_tensors="pt").to(torch_device)
|
||||||
|
src_lan = self.tokenizer._convert_lang_code_special_format("en_XX")
|
||||||
outputs = self.model.generate(
|
outputs = self.model.generate(
|
||||||
inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
|
inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id[src_lan], num_beams=1
|
||||||
)
|
)
|
||||||
prediction: str = self.tokenizer.batch_decode(
|
prediction: str = self.tokenizer.batch_decode(
|
||||||
outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True
|
outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True
|
||||||
|
|||||||
Reference in New Issue
Block a user