add special tokens to pretrain configs of respective lm head models
This commit is contained in:
@@ -513,14 +513,8 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
def prepare_generation_special_tokens():
|
||||
return {"bos_token_id": 1, "pad_token_id": 5, "eos_token_id": 2}
|
||||
|
||||
|
||||
class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
special_tokens = prepare_generation_special_tokens()
|
||||
|
||||
@slow
|
||||
def test_lm_generate_xlnet_base_cased(self):
|
||||
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
|
||||
@@ -917,12 +911,6 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
||||
# Since, however, he has had difficulty walking with Maria
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
bos_token_id=self.special_tokens["bos_token_id"],
|
||||
pad_token_id=self.special_tokens["pad_token_id"],
|
||||
eos_token_ids=self.special_tokens["eos_token_id"],
|
||||
max_length=200,
|
||||
)
|
||||
output_ids = model.generate(input_ids, max_length=200)
|
||||
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
Reference in New Issue
Block a user