From e53331c905de1af4c7c31563db86cfabd4f1e0f9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 22 Nov 2022 17:56:04 +0000 Subject: [PATCH] Generate: fix plbart generation tests (#20391) --- tests/models/plbart/test_modeling_plbart.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/models/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py index a17d2d0db6..38eca39b28 100644 --- a/tests/models/plbart/test_modeling_plbart.py +++ b/tests/models/plbart/test_modeling_plbart.py @@ -409,12 +409,12 @@ class PLBartBaseIntegrationTest(AbstractSeq2SeqIntegrationTest): src_text = ["Is 0 the first Fibonacci number ?", "Find the sum of all prime numbers ."] tgt_text = ["0 the first Fibonacci number?", "the sum of all prime numbers.......... the the"] - @unittest.skip("This test is broken, fix me gante") def test_base_generate(self): inputs = self.tokenizer([self.src_text[0]], return_tensors="pt").to(torch_device) + src_lan = self.tokenizer._convert_lang_code_special_format("en_XX") translated_tokens = self.model.generate( input_ids=inputs["input_ids"].to(torch_device), - decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], + decoder_start_token_id=self.tokenizer.lang_code_to_id[src_lan], ) decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) self.assertEqual(self.tgt_text[0], decoded[0]) @@ -422,8 +422,9 @@ class PLBartBaseIntegrationTest(AbstractSeq2SeqIntegrationTest): @slow def test_fill_mask(self): inputs = self.tokenizer(["Is 0 the Fibonacci ?"], return_tensors="pt").to(torch_device) + src_lan = self.tokenizer._convert_lang_code_special_format("en_XX") outputs = self.model.generate( - inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1 + inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id[src_lan], num_beams=1 ) prediction: str = self.tokenizer.batch_decode( outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True