MbartTokenizer: do not hardcode vocab size (#5998)
This commit is contained in:
@@ -113,10 +113,15 @@ class MBartEnroIntegrationTest(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
+        cls.tokenizer: MBartTokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name)
         cls.pad_token_id = 1
         return cls
 
+    def check_language_codes(self):
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ar_AR"], 250001)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
+        self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
+
     def test_enro_tokenizer_prepare_translation_batch(self):
         batch = self.tokenizer.prepare_translation_batch(
             self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
Reference in New Issue
Block a user