Update Special Language Tokens for PLBART (#19980)
* Update Special Language Tokens for PLBART * fix format * making mapping for language codes and updating tests: * fix format * fix consistency * add assert to both tokenizer tests. * fix format * Update src/transformers/models/plbart/tokenization_plbart.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * improvin readability, setting self.tgt_lang * fixing * readability Co-authored-by: jordiclive <jordiclive19@imperial.ac.uk> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -129,7 +129,14 @@ class PLBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
end = tokenizer.vocab_size
|
||||
language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 4, end)]
|
||||
|
||||
self.assertListEqual(language_tokens, ["java", "python", "en_XX", "<mask>"])
|
||||
self.assertListEqual(language_tokens, ["__java__", "__python__", "__en_XX__", "<mask>"])
|
||||
|
||||
code = "java.lang.Exception, python.lang.Exception, javascript, php, ruby, go"
|
||||
input_ids = tokenizer(code).input_ids
|
||||
self.assertEqual(
|
||||
tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False),
|
||||
code,
|
||||
)
|
||||
|
||||
def test_full_multi_tokenizer(self):
|
||||
tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="multi", keep_accents=True)
|
||||
@@ -208,7 +215,15 @@ class PLBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
end = tokenizer.vocab_size
|
||||
language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 7, end)]
|
||||
|
||||
self.assertListEqual(language_tokens, ["java", "python", "en_XX", "javascript", "php", "ruby", "go"])
|
||||
self.assertListEqual(
|
||||
language_tokens, ["__java__", "__python__", "__en_XX__", "__javascript__", "__php__", "__ruby__", "__go__"]
|
||||
)
|
||||
code = "java.lang.Exception, python.lang.Exception, javascript, php, ruby, go"
|
||||
input_ids = tokenizer(code).input_ids
|
||||
self.assertEqual(
|
||||
tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False),
|
||||
code,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -262,9 +277,9 @@ class PLBartPythonEnIntegrationTest(unittest.TestCase):
|
||||
return cls
|
||||
|
||||
def check_language_codes(self):
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["java"], 50001)
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["python"], 50002)
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_XX"], 50003)
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["__java__"], 50001)
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["__python__"], 50002)
|
||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["__en_XX__"], 50003)
|
||||
|
||||
def test_python_en_tokenizer_batch_encode_plus(self):
|
||||
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
|
||||
@@ -288,7 +303,7 @@ class PLBartPythonEnIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(len(ids), desired_max_length)
|
||||
|
||||
def test_mask_token(self):
|
||||
self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "java"]), [50004, 50001])
|
||||
self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "__java__"]), [50004, 50001])
|
||||
|
||||
def test_special_tokens_unaffacted_by_save_load(self):
|
||||
tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
Reference in New Issue
Block a user