Update Special Language Tokens for PLBART (#19980)
* Update Special Language Tokens for PLBART * fix format * making mapping for language codes and updating tests: * fix format * fix consistency * add assert to both tokenizer tests. * fix format * Update src/transformers/models/plbart/tokenization_plbart.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * improvin readability, setting self.tgt_lang * fixing * readability Co-authored-by: jordiclive <jordiclive19@imperial.ac.uk> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
0
src/transformers/models/plbart/modeling_plbart.py
Executable file → Normal file
0
src/transformers/models/plbart/modeling_plbart.py
Executable file → Normal file
@@ -88,8 +88,18 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
FAIRSEQ_LANGUAGE_CODES = {
|
FAIRSEQ_LANGUAGE_CODES = {
|
||||||
"base": ["java", "python", "en_XX"],
|
"base": ["__java__", "__python__", "__en_XX__"],
|
||||||
"multi": ["java", "python", "en_XX", "javascript", "php", "ruby", "go"],
|
"multi": ["__java__", "__python__", "__en_XX__", "__javascript__", "__php__", "__ruby__", "__go__"],
|
||||||
|
}
|
||||||
|
|
||||||
|
FAIRSEQ_LANGUAGE_CODES_MAP = {
|
||||||
|
"java": "__java__",
|
||||||
|
"python": "__python__",
|
||||||
|
"en_XX": "__en_XX__",
|
||||||
|
"javascript": "__javascript__",
|
||||||
|
"php": "__php__",
|
||||||
|
"ruby": "__ruby__",
|
||||||
|
"go": "__go__",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -202,6 +212,8 @@ class PLBartTokenizer(PreTrainedTokenizer):
|
|||||||
sp_model_kwargs=self.sp_model_kwargs,
|
sp_model_kwargs=self.sp_model_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
src_lang = self._convert_lang_code_special_format(src_lang)
|
||||||
|
tgt_lang = self._convert_lang_code_special_format(tgt_lang)
|
||||||
|
|
||||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||||
self.sp_model.Load(str(vocab_file))
|
self.sp_model.Load(str(vocab_file))
|
||||||
@@ -247,7 +259,7 @@ class PLBartTokenizer(PreTrainedTokenizer):
|
|||||||
self.lang_code_to_id[self._src_lang] if self._src_lang is not None else self._src_lang
|
self.lang_code_to_id[self._src_lang] if self._src_lang is not None else self._src_lang
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._src_lang = src_lang if src_lang is not None else "en_XX"
|
self._src_lang = src_lang if src_lang is not None else "__en_XX__"
|
||||||
self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
|
self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
|
||||||
|
|
||||||
self.tgt_lang = tgt_lang
|
self.tgt_lang = tgt_lang
|
||||||
@@ -284,6 +296,7 @@ class PLBartTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
@src_lang.setter
|
@src_lang.setter
|
||||||
def src_lang(self, new_src_lang: str) -> None:
|
def src_lang(self, new_src_lang: str) -> None:
|
||||||
|
new_src_lang = self._convert_lang_code_special_format(new_src_lang)
|
||||||
self._src_lang = new_src_lang
|
self._src_lang = new_src_lang
|
||||||
self.set_src_lang_special_tokens(self._src_lang)
|
self.set_src_lang_special_tokens(self._src_lang)
|
||||||
|
|
||||||
@@ -374,9 +387,10 @@ class PLBartTokenizer(PreTrainedTokenizer):
|
|||||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||||
if src_lang is None or tgt_lang is None:
|
if src_lang is None or tgt_lang is None:
|
||||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||||
self.src_lang = src_lang
|
self.src_lang = self._convert_lang_code_special_format(src_lang)
|
||||||
|
self.tgt_lang = self._convert_lang_code_special_format(tgt_lang)
|
||||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
|
||||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
tgt_lang_id = self.convert_tokens_to_ids(self.tgt_lang)
|
||||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
@@ -433,8 +447,8 @@ class PLBartTokenizer(PreTrainedTokenizer):
|
|||||||
tgt_lang: str = "python",
|
tgt_lang: str = "python",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
self.src_lang = src_lang
|
self.src_lang = self._convert_lang_code_special_format(src_lang)
|
||||||
self.tgt_lang = tgt_lang
|
self.tgt_lang = self._convert_lang_code_special_format(tgt_lang)
|
||||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||||
|
|
||||||
def _switch_to_input_mode(self):
|
def _switch_to_input_mode(self):
|
||||||
@@ -445,6 +459,7 @@ class PLBartTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||||
|
src_lang = self._convert_lang_code_special_format(src_lang)
|
||||||
self.cur_lang_code = self.lang_code_to_id[src_lang] if src_lang is not None else None
|
self.cur_lang_code = self.lang_code_to_id[src_lang] if src_lang is not None else None
|
||||||
self.prefix_tokens = []
|
self.prefix_tokens = []
|
||||||
if self.cur_lang_code is not None:
|
if self.cur_lang_code is not None:
|
||||||
@@ -454,9 +469,16 @@ class PLBartTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
def set_tgt_lang_special_tokens(self, lang: str) -> None:
|
def set_tgt_lang_special_tokens(self, lang: str) -> None:
|
||||||
"""Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
|
"""Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
|
||||||
|
lang = self._convert_lang_code_special_format(lang)
|
||||||
|
|
||||||
self.cur_lang_code = self.lang_code_to_id[lang] if lang is not None else None
|
self.cur_lang_code = self.lang_code_to_id[lang] if lang is not None else None
|
||||||
self.prefix_tokens = []
|
self.prefix_tokens = []
|
||||||
if self.cur_lang_code is not None:
|
if self.cur_lang_code is not None:
|
||||||
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
|
self.suffix_tokens = [self.eos_token_id, self.cur_lang_code]
|
||||||
else:
|
else:
|
||||||
self.suffix_tokens = [self.eos_token_id]
|
self.suffix_tokens = [self.eos_token_id]
|
||||||
|
|
||||||
|
def _convert_lang_code_special_format(self, lang: str) -> str:
|
||||||
|
"""Convert Language Codes to format tokenizer uses if required"""
|
||||||
|
lang = FAIRSEQ_LANGUAGE_CODES_MAP[lang] if lang in FAIRSEQ_LANGUAGE_CODES_MAP.keys() else lang
|
||||||
|
return lang
|
||||||
|
|||||||
@@ -129,7 +129,14 @@ class PLBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
end = tokenizer.vocab_size
|
end = tokenizer.vocab_size
|
||||||
language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 4, end)]
|
language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 4, end)]
|
||||||
|
|
||||||
self.assertListEqual(language_tokens, ["java", "python", "en_XX", "<mask>"])
|
self.assertListEqual(language_tokens, ["__java__", "__python__", "__en_XX__", "<mask>"])
|
||||||
|
|
||||||
|
code = "java.lang.Exception, python.lang.Exception, javascript, php, ruby, go"
|
||||||
|
input_ids = tokenizer(code).input_ids
|
||||||
|
self.assertEqual(
|
||||||
|
tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False),
|
||||||
|
code,
|
||||||
|
)
|
||||||
|
|
||||||
def test_full_multi_tokenizer(self):
|
def test_full_multi_tokenizer(self):
|
||||||
tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="multi", keep_accents=True)
|
tokenizer = PLBartTokenizer(SAMPLE_VOCAB, language_codes="multi", keep_accents=True)
|
||||||
@@ -208,7 +215,15 @@ class PLBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
end = tokenizer.vocab_size
|
end = tokenizer.vocab_size
|
||||||
language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 7, end)]
|
language_tokens = [tokenizer.convert_ids_to_tokens(x) for x in range(end - 7, end)]
|
||||||
|
|
||||||
self.assertListEqual(language_tokens, ["java", "python", "en_XX", "javascript", "php", "ruby", "go"])
|
self.assertListEqual(
|
||||||
|
language_tokens, ["__java__", "__python__", "__en_XX__", "__javascript__", "__php__", "__ruby__", "__go__"]
|
||||||
|
)
|
||||||
|
code = "java.lang.Exception, python.lang.Exception, javascript, php, ruby, go"
|
||||||
|
input_ids = tokenizer(code).input_ids
|
||||||
|
self.assertEqual(
|
||||||
|
tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False),
|
||||||
|
code,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -262,9 +277,9 @@ class PLBartPythonEnIntegrationTest(unittest.TestCase):
|
|||||||
return cls
|
return cls
|
||||||
|
|
||||||
def check_language_codes(self):
|
def check_language_codes(self):
|
||||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["java"], 50001)
|
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["__java__"], 50001)
|
||||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["python"], 50002)
|
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["__python__"], 50002)
|
||||||
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_XX"], 50003)
|
self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["__en_XX__"], 50003)
|
||||||
|
|
||||||
def test_python_en_tokenizer_batch_encode_plus(self):
|
def test_python_en_tokenizer_batch_encode_plus(self):
|
||||||
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
|
ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
|
||||||
@@ -288,7 +303,7 @@ class PLBartPythonEnIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(len(ids), desired_max_length)
|
self.assertEqual(len(ids), desired_max_length)
|
||||||
|
|
||||||
def test_mask_token(self):
|
def test_mask_token(self):
|
||||||
self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "java"]), [50004, 50001])
|
self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "__java__"]), [50004, 50001])
|
||||||
|
|
||||||
def test_special_tokens_unaffacted_by_save_load(self):
|
def test_special_tokens_unaffacted_by_save_load(self):
|
||||||
tmpdirname = tempfile.mkdtemp()
|
tmpdirname = tempfile.mkdtemp()
|
||||||
|
|||||||
Reference in New Issue
Block a user