🚨🚨🚨 [NLLB Tokenizer] Fix the prefix tokens 🚨🚨🚨 (#22313)

* fix the prefix tokens

* update fast and test values

* add legacy behaviour

Co-authored-by: sgugger <sylvain.gugger@gmail.com>

* update disclaimer, linkissue PR and behaviral changes

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <hi@lysand.re>

* styling

* make a quote

* quote this time

---------

Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Arthur
2023-04-04 14:53:06 +02:00
committed by GitHub
parent ad5e9b6c6a
commit 00b5887b94
4 changed files with 111 additions and 23 deletions

View File

@@ -305,6 +305,7 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
" face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
]
expected_src_tokens = [
256047,
16297,
134408,
8165,
@@ -319,7 +320,6 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
108,
49486,
2,
256047,
]
@classmethod
@@ -355,8 +355,8 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
assert isinstance(src_text[0], str)
desired_max_length = 10
ids = self.tokenizer(src_text, max_length=desired_max_length, truncation=True).input_ids[0]
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], EN_CODE)
self.assertEqual(ids[-1], 2)
self.assertEqual(ids[0], EN_CODE)
self.assertEqual(len(ids), desired_max_length)
def test_mask_token(self):
@@ -389,10 +389,10 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
self.assertEqual((2, 15), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(self.expected_src_tokens, result)
self.assertEqual(2, batch.decoder_input_ids[0, -1]) # EOS
self.assertEqual(RO_CODE, batch.decoder_input_ids[0, 0]) # EOS
# Test that special tokens are reset
self.assertEqual(self.tokenizer.prefix_tokens, [])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE])
self.assertEqual(self.tokenizer.prefix_tokens, [EN_CODE])
self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
def test_seq2seq_max_length(self):
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
@@ -419,9 +419,27 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
nested_simplify(inputs),
{
# A, test, EOS, en_XX
"input_ids": [[70, 7356, 2, 256047]],
"input_ids": [[256047, 70, 7356, 2]],
"attention_mask": [[1, 1, 1, 1]],
# ar_AR
"forced_bos_token_id": 256057,
},
)
@require_torch
def test_legacy_behaviour(self):
self.tokenizer.legacy_behaviour = True
inputs = self.tokenizer(
"UN Chief says there is no military solution in Syria", src_lang="eng_Latn", tgt_lang="fra_Latn"
)
self.assertEqual(
inputs.input_ids, [16297, 134408, 25653, 6370, 248, 254, 103929, 94995, 108, 49486, 2, 256047]
)
self.tokenizer.legacy_behaviour = False
inputs = self.tokenizer(
"UN Chief says there is no military solution in Syria", src_lang="eng_Latn", tgt_lang="fra_Latn"
)
self.assertEqual(
inputs.input_ids, [256047, 16297, 134408, 25653, 6370, 248, 254, 103929, 94995, 108, 49486, 2]
)