From 92970c0cb9826be8b935e3b50b9c0ad0e2f6c62f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 16 Apr 2021 11:31:35 +0200 Subject: [PATCH] Enabling multilingual models for translation pipelines. (#10536) * [WIP] Enabling multilingual models for translation pipelines. * decoder_input_ids -> forced_bos_token_id * Improve docstring. * Rebase * Fixing 2 bugs - Type token_ids coming from `_parse_and_tokenize` - Wrong index from tgt_lang. * Fixing black version. * Adding tests for _build_translation_inputs and add them for all tokenizers. * Mbart actually puts the lang code at the end. * Fixing m2m100. * Adding TF support to `deep_round`. * Update src/transformers/pipelines/text2text_generation.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding one line comment. * Fixing M2M100 `_build_translation_input_ids`, and fix the call site. * Fixing tests + deep_round -> nested_simplify Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../models/m2m_100/tokenization_m2m_100.py | 10 ++ .../models/mbart/tokenization_mbart.py | 10 ++ .../models/mbart/tokenization_mbart50.py | 10 ++ .../models/mbart/tokenization_mbart50_fast.py | 10 ++ .../models/mbart/tokenization_mbart_fast.py | 10 ++ src/transformers/pipelines/base.py | 5 +- .../pipelines/text2text_generation.py | 150 ++++++++++++------ src/transformers/testing_utils.py | 26 +++ tests/test_pipelines_translation.py | 38 ++++- tests/test_tokenization_m2m_100.py | 17 +- tests/test_tokenization_mbart.py | 17 +- tests/test_tokenization_mbart50.py | 17 +- 12 files changed, 268 insertions(+), 52 deletions(-) diff --git a/src/transformers/models/m2m_100/tokenization_m2m_100.py b/src/transformers/models/m2m_100/tokenization_m2m_100.py index 3d2f273d72..e39fbbd7aa 100644 --- a/src/transformers/models/m2m_100/tokenization_m2m_100.py +++ b/src/transformers/models/m2m_100/tokenization_m2m_100.py @@ -288,6 +288,16 @@ class M2M100Tokenizer(PreTrainedTokenizer): self.set_src_lang_special_tokens(self.src_lang) return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs) + def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) + tgt_lang_id = self.get_lang_id(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + @contextmanager def as_target_tokenizer(self): """ diff --git a/src/transformers/models/mbart/tokenization_mbart.py b/src/transformers/models/mbart/tokenization_mbart.py index a38aaf7ef3..ac5e62bda4 100644 --- a/src/transformers/models/mbart/tokenization_mbart.py +++ b/src/transformers/models/mbart/tokenization_mbart.py @@ -186,6 +186,16 @@ class MBartTokenizer(XLMRobertaTokenizer): # We don't expect to process pairs, but leave the pair logic for API consistency return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + def prepare_seq2seq_batch( self, src_texts: List[str], diff --git a/src/transformers/models/mbart/tokenization_mbart50.py b/src/transformers/models/mbart/tokenization_mbart50.py index 5afd9b215f..48fdfe7772 100644 --- a/src/transformers/models/mbart/tokenization_mbart50.py +++ b/src/transformers/models/mbart/tokenization_mbart50.py @@ -278,6 +278,16 @@ class MBart50Tokenizer(PreTrainedTokenizer): # We don't expect to process pairs, but leave the pair logic for API consistency return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + def prepare_seq2seq_batch( self, src_texts: List[str], diff --git a/src/transformers/models/mbart/tokenization_mbart50_fast.py b/src/transformers/models/mbart/tokenization_mbart50_fast.py index f22d02e59b..b4534b65c5 100644 --- a/src/transformers/models/mbart/tokenization_mbart50_fast.py +++ b/src/transformers/models/mbart/tokenization_mbart50_fast.py @@ -241,6 +241,16 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast): special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)), ) + def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: if not os.path.isdir(save_directory): logger.error(f"Vocabulary path ({save_directory}) should be a directory") diff --git a/src/transformers/models/mbart/tokenization_mbart_fast.py b/src/transformers/models/mbart/tokenization_mbart_fast.py index bbe9ed7d5d..4b4154e6a6 100644 --- a/src/transformers/models/mbart/tokenization_mbart_fast.py +++ b/src/transformers/models/mbart/tokenization_mbart_fast.py @@ -160,6 +160,16 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast): # We don't expect to process pairs, but leave the pair logic for API consistency return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): + """Used by translation pipeline, to prepare inputs for the generate function""" + if src_lang is None or tgt_lang is None: + raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") + self.src_lang = src_lang + inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) + tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) + inputs["forced_bos_token_id"] = tgt_lang_id + return inputs + def prepare_seq2seq_batch( self, src_texts: List[str], diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index d06376aa43..63ddd79971 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -616,7 +616,10 @@ class Pipeline(_ScikitCompat): Return: :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device. """ - return {name: tensor.to(self.device) for name, tensor in inputs.items()} + return { + name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor + for name, tensor in inputs.items() + } def check_model_type(self, supported_models: Union[List[str], dict]): """ diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index bda4457ea8..7a6564aaa4 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -1,3 +1,5 @@ +from typing import Optional + from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available from ..tokenization_utils import TruncationStrategy from ..utils import logging @@ -50,6 +52,28 @@ class Text2TextGenerationPipeline(Pipeline): """ return True + def _parse_and_tokenize(self, *args, truncation): + prefix = self.model.config.prefix if self.model.config.prefix is not None else "" + if isinstance(args[0], list): + assert ( + self.tokenizer.pad_token_id is not None + ), "Please make sure that the tokenizer has a pad_token_id when using a batch input" + args = ([prefix + arg for arg in args[0]],) + padding = True + + elif isinstance(args[0], str): + args = (prefix + args[0],) + padding = False + else: + raise ValueError( + f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`" + ) + inputs = super()._parse_and_tokenize(*args, padding=padding, truncation=truncation) + # This is produced by tokenizers but is an invalid generate kwargs + if "token_type_ids" in inputs: + del inputs["token_type_ids"] + return inputs + def __call__( self, *args, @@ -88,53 +112,41 @@ class Text2TextGenerationPipeline(Pipeline): """ assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True" - prefix = self.model.config.prefix if self.model.config.prefix is not None else "" - if isinstance(args[0], list): - assert ( - self.tokenizer.pad_token_id is not None - ), "Please make sure that the tokenizer has a pad_token_id when using a batch input" - args = ([prefix + arg for arg in args[0]],) - padding = True - - elif isinstance(args[0], str): - args = (prefix + args[0],) - padding = False - else: - raise ValueError( - f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`" - ) - with self.device_placement(): - inputs = self._parse_and_tokenize(*args, padding=padding, truncation=truncation) + inputs = self._parse_and_tokenize(*args, truncation=truncation) + return self._generate(inputs, return_tensors, return_text, clean_up_tokenization_spaces, generate_kwargs) - if self.framework == "pt": - inputs = self.ensure_tensor_on_device(**inputs) - input_length = inputs["input_ids"].shape[-1] - elif self.framework == "tf": - input_length = tf.shape(inputs["input_ids"])[-1].numpy() + def _generate( + self, inputs, return_tensors: bool, return_text: bool, clean_up_tokenization_spaces: bool, generate_kwargs + ): + if self.framework == "pt": + inputs = self.ensure_tensor_on_device(**inputs) + input_length = inputs["input_ids"].shape[-1] + elif self.framework == "tf": + input_length = tf.shape(inputs["input_ids"])[-1].numpy() - min_length = generate_kwargs.get("min_length", self.model.config.min_length) - max_length = generate_kwargs.get("max_length", self.model.config.max_length) - self.check_inputs(input_length, min_length, max_length) + min_length = generate_kwargs.get("min_length", self.model.config.min_length) + max_length = generate_kwargs.get("max_length", self.model.config.max_length) + self.check_inputs(input_length, min_length, max_length) - generations = self.model.generate( - inputs["input_ids"], - attention_mask=inputs["attention_mask"], - **generate_kwargs, - ) - results = [] - for generation in generations: - record = {} - if return_tensors: - record[f"{self.return_name}_token_ids"] = generation - if return_text: - record[f"{self.return_name}_text"] = self.tokenizer.decode( - generation, - skip_special_tokens=True, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - ) - results.append(record) - return results + generate_kwargs.update(inputs) + + generations = self.model.generate( + **generate_kwargs, + ) + results = [] + for generation in generations: + record = {} + if return_tensors: + record[f"{self.return_name}_token_ids"] = generation + if return_text: + record[f"{self.return_name}_text"] = self.tokenizer.decode( + generation, + skip_special_tokens=True, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + ) + results.append(record) + return results @add_end_docstrings(PIPELINE_INIT_ARGS) @@ -226,6 +238,23 @@ class TranslationPipeline(Text2TextGenerationPipeline): # Used in the return key of the pipeline. return_name = "translation" + src_lang: Optional[str] = None + tgt_lang: Optional[str] = None + + def __init__(self, *args, src_lang=None, tgt_lang=None, **kwargs): + super().__init__(*args, **kwargs) + if src_lang is not None: + self.src_lang = src_lang + if tgt_lang is not None: + self.tgt_lang = tgt_lang + if src_lang is None and tgt_lang is None: + # Backward compatibility, direct arguments use is preferred. + task = kwargs.get("task", "") + items = task.split("_") + if task and len(items) == 4: + # translation, XX, to YY + self.src_lang = items[1] + self.tgt_lang = items[3] def check_inputs(self, input_length: int, min_length: int, max_length: int): if input_length > 0.9 * max_length: @@ -233,8 +262,27 @@ class TranslationPipeline(Text2TextGenerationPipeline): f"Your input_length: {input_length} is bigger than 0.9 * max_length: {max_length}. You might consider " "increasing your max_length manually, e.g. translator('...', max_length=400)" ) + return True - def __call__(self, *args, **kwargs): + def _parse_and_tokenize(self, *args, src_lang, tgt_lang, truncation): + if getattr(self.tokenizer, "_build_translation_inputs", None): + return self.tokenizer._build_translation_inputs( + *args, src_lang=src_lang, tgt_lang=tgt_lang, truncation=truncation + ) + else: + return super()._parse_and_tokenize(*args, truncation=truncation) + + def __call__( + self, + *args, + return_tensors=False, + return_text=True, + clean_up_tokenization_spaces=False, + truncation=TruncationStrategy.DO_NOT_TRUNCATE, + src_lang=None, + tgt_lang=None, + **generate_kwargs + ): r""" Translate the text(s) given as inputs. @@ -247,6 +295,12 @@ class TranslationPipeline(Text2TextGenerationPipeline): Whether or not to include the decoded texts in the outputs. clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to clean up the potential extra spaces in the text output. + src_lang (:obj:`str`, `optional`, defaults to :obj:`None`): + The language of the input. Might be required for multilingual models. Will not have any effect for + single pair translation models + tgt_lang (:obj:`str`, `optional`, defaults to :obj:`None`): + The language of the desired output. Might be required for multilingual models. Will not have any effect + for single pair translation models generate_kwargs: Additional keyword arguments to pass along to the generate method of the model (see the generate method corresponding to your framework `here <./model.html#generative-models>`__). @@ -258,4 +312,10 @@ class TranslationPipeline(Text2TextGenerationPipeline): - **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``) -- The token ids of the translation. """ - return super().__call__(*args, **kwargs) + assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True" + src_lang = src_lang if src_lang is not None else self.src_lang + tgt_lang = tgt_lang if tgt_lang is not None else self.tgt_lang + + with self.device_placement(): + inputs = self._parse_and_tokenize(*args, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang) + return self._generate(inputs, return_tensors, return_text, clean_up_tokenization_spaces, generate_kwargs) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index a5c4e7d2b8..283ec1eb4d 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -361,6 +361,9 @@ if is_torch_available(): else: torch_device = None +if is_tf_available(): + import tensorflow as tf + def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch. """ @@ -1174,3 +1177,26 @@ def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False raise RuntimeError(f"'{cmd_str}' produced no output.") return result + + +def nested_simplify(obj, decimals=3): + """ + Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test + within tests. + """ + from transformers.tokenization_utils import BatchEncoding + + if isinstance(obj, list): + return [nested_simplify(item, decimals) for item in obj] + elif isinstance(obj, (dict, BatchEncoding)): + return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()} + elif isinstance(obj, (str, int)): + return obj + elif is_torch_available() and isinstance(obj, torch.Tensor): + return nested_simplify(obj.tolist()) + elif is_tf_available() and tf.is_tensor(obj): + return nested_simplify(obj.numpy().tolist()) + elif isinstance(obj, float): + return round(obj, decimals) + else: + raise Exception(f"Not supported: {type(obj)}") diff --git a/tests/test_pipelines_translation.py b/tests/test_pipelines_translation.py index 0f866a09b7..dba66d1219 100644 --- a/tests/test_pipelines_translation.py +++ b/tests/test_pipelines_translation.py @@ -17,11 +17,15 @@ import unittest import pytest from transformers import pipeline -from transformers.testing_utils import is_pipeline_test, require_torch, slow +from transformers.testing_utils import is_pipeline_test, is_torch_available, require_torch, slow from .test_pipelines_common import MonoInputPipelineCommonMixin +if is_torch_available(): + from transformers.models.mbart import MBart50TokenizerFast, MBartForConditionalGeneration + + class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): pipeline_task = "translation_en_to_de" small_models = ["patrickvonplaten/t5-tiny-random"] # Default model - Models tested without the @slow decorator @@ -48,12 +52,38 @@ class TranslationNewFormatPipelineTests(unittest.TestCase): pipeline(task="translation_cn_to_ar") # but we do for this one - pipeline(task="translation_en_to_de") + translator = pipeline(task="translation_en_to_de") + self.assertEquals(translator.src_lang, "en") + self.assertEquals(translator.tgt_lang, "de") + + @require_torch + @slow + def test_multilingual_translation(self): + model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt") + tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt") + + translator = pipeline(task="translation", model=model, tokenizer=tokenizer) + # Missing src_lang, tgt_lang + with self.assertRaises(ValueError): + translator("This is a test") + + outputs = translator("This is a test", src_lang="en_XX", tgt_lang="ar_AR") + self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}]) + + outputs = translator("This is a test", src_lang="en_XX", tgt_lang="hi_IN") + self.assertEqual(outputs, [{"translation_text": "यह एक परीक्षण है"}]) + + # src_lang, tgt_lang can be defined at pipeline call time + translator = pipeline(task="translation", model=model, tokenizer=tokenizer, src_lang="en_XX", tgt_lang="ar_AR") + outputs = translator("This is a test") + self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}]) @require_torch def test_translation_on_odd_language(self): model = "patrickvonplaten/t5-tiny-random" - pipeline(task="translation_cn_to_ar", model=model) + translator = pipeline(task="translation_cn_to_ar", model=model) + self.assertEquals(translator.src_lang, "cn") + self.assertEquals(translator.tgt_lang, "ar") @require_torch def test_translation_default_language_selection(self): @@ -61,6 +91,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase): with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"): nlp = pipeline(task="translation", model=model) self.assertEqual(nlp.task, "translation_en_to_de") + self.assertEquals(nlp.src_lang, "en") + self.assertEquals(nlp.tgt_lang, "de") @require_torch def test_translation_with_no_language_no_model_fails(self): diff --git a/tests/test_tokenization_m2m_100.py b/tests/test_tokenization_m2m_100.py index 649d471deb..4f7cf6ffae 100644 --- a/tests/test_tokenization_m2m_100.py +++ b/tests/test_tokenization_m2m_100.py @@ -20,7 +20,7 @@ from shutil import copyfile from transformers import M2M100Tokenizer, is_torch_available from transformers.file_utils import is_sentencepiece_available -from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch +from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch if is_sentencepiece_available(): @@ -191,3 +191,18 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase): self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")]) self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id]) self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)]) + + @require_torch + def test_tokenizer_translation(self): + inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en", tgt_lang="ar") + + self.assertEqual( + nested_simplify(inputs), + { + # en_XX, A, test, EOS + "input_ids": [[128022, 58, 4183, 2]], + "attention_mask": [[1, 1, 1, 1]], + # ar_AR + "forced_bos_token_id": 128006, + }, + ) diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py index 83c2d33b6f..640aec60fd 100644 --- a/tests/test_tokenization_mbart.py +++ b/tests/test_tokenization_mbart.py @@ -17,7 +17,7 @@ import tempfile import unittest from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available -from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch +from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch from .test_tokenization_common import TokenizerTesterMixin @@ -232,3 +232,18 @@ class MBartEnroIntegrationTest(unittest.TestCase): self.assertEqual(batch.input_ids.shape[1], 3) self.assertEqual(batch.decoder_input_ids.shape[1], 10) + + @require_torch + def test_tokenizer_translation(self): + inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR") + + self.assertEqual( + nested_simplify(inputs), + { + # A, test, EOS, en_XX + "input_ids": [[62, 3034, 2, 250004]], + "attention_mask": [[1, 1, 1, 1]], + # ar_AR + "forced_bos_token_id": 250001, + }, + ) diff --git a/tests/test_tokenization_mbart50.py b/tests/test_tokenization_mbart50.py index 4c3561a907..49dfc0b66f 100644 --- a/tests/test_tokenization_mbart50.py +++ b/tests/test_tokenization_mbart50.py @@ -17,7 +17,7 @@ import tempfile import unittest from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available -from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch +from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch from .test_tokenization_common import TokenizerTesterMixin @@ -194,3 +194,18 @@ class MBartOneToManyIntegrationTest(unittest.TestCase): self.assertEqual(batch.input_ids.shape[1], 3) self.assertEqual(batch.decoder_input_ids.shape[1], 10) + + @require_torch + def test_tokenizer_translation(self): + inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR") + + self.assertEqual( + nested_simplify(inputs), + { + # en_XX, A, test, EOS + "input_ids": [[250004, 62, 3034, 2]], + "attention_mask": [[1, 1, 1, 1]], + # ar_AR + "forced_bos_token_id": 250001, + }, + )