add special tokens to pretrain configs of respective lm head models

This commit is contained in:
Patrick von Platen
2020-02-25 16:37:56 +01:00
parent e693cd1e87
commit e645dcbb70
4 changed files with 5 additions and 49 deletions

View File

@@ -263,14 +263,8 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
def prepare_generation_special_tokens():
return {"bos_token_id": 50256, "eos_token_id": 50256}
class GPT2ModelLanguageGenerationTest(unittest.TestCase):
special_tokens = prepare_generation_special_tokens()
@slow
def test_lm_generate_gpt2(self):
model = GPT2LMHeadModel.from_pretrained("gpt2")
@@ -299,11 +293,7 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
] # The dog is cute too. It likes to rub on me and is good for me (the dog
torch.manual_seed(0)
output_ids = model.generate(
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
eos_token_ids=self.special_tokens["eos_token_id"],
)
output_ids = model.generate(input_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
@@ -335,10 +325,5 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
] # The dog is cute though he can sometimes just walk in the park. It is not very nice to
torch.manual_seed(0)
output_ids = model.generate(
input_ids,
bos_token_id=self.special_tokens["bos_token_id"],
eos_token_ids=self.special_tokens["eos_token_id"],
)
output_ids = model.generate(input_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)