diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 2f6f1dfdbb..0ec1be4812 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -263,14 +263,8 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase): self.assertIsNotNone(model) -def prepare_generation_special_tokens(): - return {"bos_token_id": 50256, "eos_token_id": 50256} - - class GPT2ModelLanguageGenerationTest(unittest.TestCase): - special_tokens = prepare_generation_special_tokens() - @slow def test_lm_generate_gpt2(self): model = GPT2LMHeadModel.from_pretrained("gpt2") @@ -299,11 +293,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): ] # The dog is cute too. It likes to rub on me and is good for me (the dog torch.manual_seed(0) - output_ids = model.generate( - input_ids, - bos_token_id=self.special_tokens["bos_token_id"], - eos_token_ids=self.special_tokens["eos_token_id"], - ) + output_ids = model.generate(input_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids) @@ -335,10 +325,5 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): ] # The dog is cute though he can sometimes just walk in the park. It is not very nice to torch.manual_seed(0) - output_ids = model.generate( - input_ids, - bos_token_id=self.special_tokens["bos_token_id"], - eos_token_ids=self.special_tokens["eos_token_id"], - ) - + output_ids = model.generate(input_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids) diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index 32987b6ac2..f59ed628eb 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -214,14 +214,8 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase): self.assertIsNotNone(model) -def prepare_generation_special_tokens(): - return {"eos_token_id": 0} - - class TransfoXLModelLanguageGenerationTest(unittest.TestCase): - special_tokens = prepare_generation_special_tokens() - @slow def test_lm_generate_transfo_xl_wt103(self): model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103") @@ -578,6 +572,5 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase): torch.manual_seed(0) - output_ids = model.generate(input_ids, eos_token_ids=self.special_tokens["eos_token_id"], max_length=200) - + output_ids = model.generate(input_ids, max_length=200) self.assertListEqual(output_ids[0].tolist(), expected_output_ids) diff --git a/tests/test_modeling_xlm.py b/tests/test_modeling_xlm.py index 99e6608917..bf7d7b485b 100644 --- a/tests/test_modeling_xlm.py +++ b/tests/test_modeling_xlm.py @@ -399,14 +399,8 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase): self.assertIsNotNone(model) -def prepare_generation_special_tokens(): - return {"bos_token_id": 0, "pad_token_id": 2} - - class XLMModelLanguageGenerationTest(unittest.TestCase): - special_tokens = prepare_generation_special_tokens() - @slow def test_lm_generate_xlm_mlm_en_2048(self): model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048") @@ -435,10 +429,6 @@ class XLMModelLanguageGenerationTest(unittest.TestCase): ] # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation. torch.manual_seed(0) - output_ids = model.generate( - input_ids, - bos_token_id=self.special_tokens["bos_token_id"], - pad_token_id=self.special_tokens["pad_token_id"], - ) + output_ids = model.generate(input_ids) self.assertListEqual(output_ids[0].tolist(), expected_output_ids) diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index 7989430802..160da136e5 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -513,14 +513,8 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): self.assertIsNotNone(model) -def prepare_generation_special_tokens(): - return {"bos_token_id": 1, "pad_token_id": 5, "eos_token_id": 2} - - class XLNetModelLanguageGenerationTest(unittest.TestCase): - special_tokens = prepare_generation_special_tokens() - @slow def test_lm_generate_xlnet_base_cased(self): model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased") @@ -917,12 +911,6 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase): # Since, however, he has had difficulty walking with Maria torch.manual_seed(0) - output_ids = model.generate( - input_ids, - bos_token_id=self.special_tokens["bos_token_id"], - pad_token_id=self.special_tokens["pad_token_id"], - eos_token_ids=self.special_tokens["eos_token_id"], - max_length=200, - ) + output_ids = model.generate(input_ids, max_length=200) self.assertListEqual(output_ids[0].tolist(), expected_output_ids)