🚨🚨🚨 [NLLB Tokenizer] Fix the prefix tokens 🚨🚨🚨 (#22313)
* fix the prefix tokens * update fast and test values * add legacy behaviour Co-authored-by: sgugger <sylvain.gugger@gmail.com> * update disclaimer, link issue PR and behavioral changes * Apply suggestions from code review Co-authored-by: Lysandre Debut <hi@lysand.re> * styling * make a quote * quote this time --------- Co-authored-by: sgugger <sylvain.gugger@gmail.com> Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -305,6 +305,7 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
|
||||
" face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
|
||||
]
|
||||
expected_src_tokens = [
|
||||
256047,
|
||||
16297,
|
||||
134408,
|
||||
8165,
|
||||
@@ -319,7 +320,6 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
|
||||
108,
|
||||
49486,
|
||||
2,
|
||||
256047,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
@@ -355,8 +355,8 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
|
||||
assert isinstance(src_text[0], str)
|
||||
desired_max_length = 10
|
||||
ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
|
||||
self.assertEqual(ids[-2], 2)
|
||||
self.assertEqual(ids[-1], EN_CODE)
|
||||
self.assertEqual(ids[-1], 2)
|
||||
self.assertEqual(ids[0], EN_CODE)
|
||||
self.assertEqual(len(ids), desired_max_length)
|
||||
|
||||
def test_mask_token(self):
|
||||
@@ -389,10 +389,10 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual((2, 15), batch.attention_mask.shape)
|
||||
result = batch.input_ids.tolist()[0]
|
||||
self.assertListEqual(self.expected_src_tokens, result)
|
||||
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
|
||||
self.assertEqual(RO_CODE, batch.decoder_input_ids[0, 0]) # EOS
|
||||
# Test that special tokens are reset
|
||||
self.assertEqual(self.tokenizer.prefix_tokens, [])
|
||||
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
|
||||
self.assertEqual(self.tokenizer.prefix_tokens, [EN_CODE])
|
||||
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
|
||||
def test_seq2seq_max_length(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
@@ -419,9 +419,27 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
|
||||
nested_simplify(inputs),
|
||||
{
|
||||
# A, test, EOS, en_XX
|
||||
"input_ids": [[70, 7356, 2, 256047]],
|
||||
"input_ids": [[256047, 70, 7356, 2]],
|
||||
"attention_mask": [[1, 1, 1, 1]],
|
||||
# ar_AR
|
||||
"forced_bos_token_id": 256057,
|
||||
},
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_legacy_behaviour(self):
|
||||
self.tokenizer.legacy_behaviour = True
|
||||
inputs = self.tokenizer(
|
||||
"UN Chief says there is no military solution in Syria", src_lang="eng_Latn", tgt_lang="fra_Latn"
|
||||
)
|
||||
self.assertEqual(
|
||||
inputs.input_ids, [16297, 134408, 25653, 6370, 248, 254, 103929, 94995, 108, 49486, 2, 256047]
|
||||
)
|
||||
|
||||
self.tokenizer.legacy_behaviour = False
|
||||
inputs = self.tokenizer(
|
||||
"UN Chief says there is no military solution in Syria", src_lang="eng_Latn", tgt_lang="fra_Latn"
|
||||
)
|
||||
self.assertEqual(
|
||||
inputs.input_ids, [256047, 16297, 134408, 25653, 6370, 248, 254, 103929, 94995, 108, 49486, 2]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user