Enabling multilingual models for translation pipelines. (#10536)
* [WIP] Enabling multilingual models for translation pipelines. * decoder_input_ids -> forced_bos_token_id * Improve docstring. * Rebase * Fixing 2 bugs - Type token_ids coming from `_parse_and_tokenize` - Wrong index from tgt_lang. * Fixing black version. * Adding tests for _build_translation_inputs and add them for all tokenizers. * Mbart actually puts the lang code at the end. * Fixing m2m100. * Adding TF support to `deep_round`. * Update src/transformers/pipelines/text2text_generation.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding one line comment. * Fixing M2M100 `_build_translation_input_ids`, and fix the call site. * Fixing tests + deep_round -> nested_simplify Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -288,6 +288,16 @@ class M2M100Tokenizer(PreTrainedTokenizer):
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||
if src_lang is None or tgt_lang is None:
|
||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||
self.src_lang = src_lang
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||
tgt_lang_id = self.get_lang_id(tgt_lang)
|
||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||
return inputs
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
|
||||
@@ -186,6 +186,16 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||
if src_lang is None or tgt_lang is None:
|
||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||
self.src_lang = src_lang
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||
return inputs
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
|
||||
@@ -278,6 +278,16 @@ class MBart50Tokenizer(PreTrainedTokenizer):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||
if src_lang is None or tgt_lang is None:
|
||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||
self.src_lang = src_lang
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||
return inputs
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
|
||||
@@ -241,6 +241,16 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
|
||||
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
|
||||
)
|
||||
|
||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||
if src_lang is None or tgt_lang is None:
|
||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||
self.src_lang = src_lang
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||
return inputs
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
|
||||
@@ -160,6 +160,16 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||
if src_lang is None or tgt_lang is None:
|
||||
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||
self.src_lang = src_lang
|
||||
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||
return inputs
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
|
||||
@@ -616,7 +616,10 @@ class Pipeline(_ScikitCompat):
|
||||
Return:
|
||||
:obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
|
||||
"""
|
||||
return {name: tensor.to(self.device) for name, tensor in inputs.items()}
|
||||
return {
|
||||
name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
|
||||
for name, tensor in inputs.items()
|
||||
}
|
||||
|
||||
def check_model_type(self, supported_models: Union[List[str], dict]):
|
||||
"""
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
|
||||
from ..tokenization_utils import TruncationStrategy
|
||||
from ..utils import logging
|
||||
@@ -50,6 +52,28 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
"""
|
||||
return True
|
||||
|
||||
def _parse_and_tokenize(self, *args, truncation):
|
||||
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
||||
if isinstance(args[0], list):
|
||||
assert (
|
||||
self.tokenizer.pad_token_id is not None
|
||||
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
|
||||
args = ([prefix + arg for arg in args[0]],)
|
||||
padding = True
|
||||
|
||||
elif isinstance(args[0], str):
|
||||
args = (prefix + args[0],)
|
||||
padding = False
|
||||
else:
|
||||
raise ValueError(
|
||||
f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`"
|
||||
)
|
||||
inputs = super()._parse_and_tokenize(*args, padding=padding, truncation=truncation)
|
||||
# This is produced by tokenizers but is an invalid generate kwargs
|
||||
if "token_type_ids" in inputs:
|
||||
del inputs["token_type_ids"]
|
||||
return inputs
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*args,
|
||||
@@ -88,25 +112,13 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
"""
|
||||
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||
|
||||
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
||||
if isinstance(args[0], list):
|
||||
assert (
|
||||
self.tokenizer.pad_token_id is not None
|
||||
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
|
||||
args = ([prefix + arg for arg in args[0]],)
|
||||
padding = True
|
||||
|
||||
elif isinstance(args[0], str):
|
||||
args = (prefix + args[0],)
|
||||
padding = False
|
||||
else:
|
||||
raise ValueError(
|
||||
f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`"
|
||||
)
|
||||
|
||||
with self.device_placement():
|
||||
inputs = self._parse_and_tokenize(*args, padding=padding, truncation=truncation)
|
||||
inputs = self._parse_and_tokenize(*args, truncation=truncation)
|
||||
return self._generate(inputs, return_tensors, return_text, clean_up_tokenization_spaces, generate_kwargs)
|
||||
|
||||
def _generate(
|
||||
self, inputs, return_tensors: bool, return_text: bool, clean_up_tokenization_spaces: bool, generate_kwargs
|
||||
):
|
||||
if self.framework == "pt":
|
||||
inputs = self.ensure_tensor_on_device(**inputs)
|
||||
input_length = inputs["input_ids"].shape[-1]
|
||||
@@ -117,9 +129,9 @@ class Text2TextGenerationPipeline(Pipeline):
|
||||
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
||||
self.check_inputs(input_length, min_length, max_length)
|
||||
|
||||
generate_kwargs.update(inputs)
|
||||
|
||||
generations = self.model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
**generate_kwargs,
|
||||
)
|
||||
results = []
|
||||
@@ -226,6 +238,23 @@ class TranslationPipeline(Text2TextGenerationPipeline):
|
||||
|
||||
# Used in the return key of the pipeline.
|
||||
return_name = "translation"
|
||||
src_lang: Optional[str] = None
|
||||
tgt_lang: Optional[str] = None
|
||||
|
||||
def __init__(self, *args, src_lang=None, tgt_lang=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if src_lang is not None:
|
||||
self.src_lang = src_lang
|
||||
if tgt_lang is not None:
|
||||
self.tgt_lang = tgt_lang
|
||||
if src_lang is None and tgt_lang is None:
|
||||
# Backward compatibility, direct arguments use is preferred.
|
||||
task = kwargs.get("task", "")
|
||||
items = task.split("_")
|
||||
if task and len(items) == 4:
|
||||
# translation, XX, to YY
|
||||
self.src_lang = items[1]
|
||||
self.tgt_lang = items[3]
|
||||
|
||||
def check_inputs(self, input_length: int, min_length: int, max_length: int):
|
||||
if input_length > 0.9 * max_length:
|
||||
@@ -233,8 +262,27 @@ class TranslationPipeline(Text2TextGenerationPipeline):
|
||||
f"Your input_length: {input_length} is bigger than 0.9 * max_length: {max_length}. You might consider "
|
||||
"increasing your max_length manually, e.g. translator('...', max_length=400)"
|
||||
)
|
||||
return True
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def _parse_and_tokenize(self, *args, src_lang, tgt_lang, truncation):
|
||||
if getattr(self.tokenizer, "_build_translation_inputs", None):
|
||||
return self.tokenizer._build_translation_inputs(
|
||||
*args, src_lang=src_lang, tgt_lang=tgt_lang, truncation=truncation
|
||||
)
|
||||
else:
|
||||
return super()._parse_and_tokenize(*args, truncation=truncation)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*args,
|
||||
return_tensors=False,
|
||||
return_text=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
truncation=TruncationStrategy.DO_NOT_TRUNCATE,
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
**generate_kwargs
|
||||
):
|
||||
r"""
|
||||
Translate the text(s) given as inputs.
|
||||
|
||||
@@ -247,6 +295,12 @@ class TranslationPipeline(Text2TextGenerationPipeline):
|
||||
Whether or not to include the decoded texts in the outputs.
|
||||
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to clean up the potential extra spaces in the text output.
|
||||
src_lang (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
The language of the input. Might be required for multilingual models. Will not have any effect for
|
||||
single pair translation models
|
||||
tgt_lang (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
The language of the desired output. Might be required for multilingual models. Will not have any effect
|
||||
for single pair translation models
|
||||
generate_kwargs:
|
||||
Additional keyword arguments to pass along to the generate method of the model (see the generate method
|
||||
corresponding to your framework `here <./model.html#generative-models>`__).
|
||||
@@ -258,4 +312,10 @@ class TranslationPipeline(Text2TextGenerationPipeline):
|
||||
- **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
|
||||
-- The token ids of the translation.
|
||||
"""
|
||||
return super().__call__(*args, **kwargs)
|
||||
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||
src_lang = src_lang if src_lang is not None else self.src_lang
|
||||
tgt_lang = tgt_lang if tgt_lang is not None else self.tgt_lang
|
||||
|
||||
with self.device_placement():
|
||||
inputs = self._parse_and_tokenize(*args, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang)
|
||||
return self._generate(inputs, return_tensors, return_text, clean_up_tokenization_spaces, generate_kwargs)
|
||||
|
||||
@@ -361,6 +361,9 @@ if is_torch_available():
|
||||
else:
|
||||
torch_device = None
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch. """
|
||||
@@ -1174,3 +1177,26 @@ def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False
|
||||
raise RuntimeError(f"'{cmd_str}' produced no output.")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def nested_simplify(obj, decimals=3):
|
||||
"""
|
||||
Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
|
||||
within tests.
|
||||
"""
|
||||
from transformers.tokenization_utils import BatchEncoding
|
||||
|
||||
if isinstance(obj, list):
|
||||
return [nested_simplify(item, decimals) for item in obj]
|
||||
elif isinstance(obj, (dict, BatchEncoding)):
|
||||
return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
|
||||
elif isinstance(obj, (str, int)):
|
||||
return obj
|
||||
elif is_torch_available() and isinstance(obj, torch.Tensor):
|
||||
return nested_simplify(obj.tolist())
|
||||
elif is_tf_available() and tf.is_tensor(obj):
|
||||
return nested_simplify(obj.numpy().tolist())
|
||||
elif isinstance(obj, float):
|
||||
return round(obj, decimals)
|
||||
else:
|
||||
raise Exception(f"Not supported: {type(obj)}")
|
||||
|
||||
@@ -17,11 +17,15 @@ import unittest
|
||||
import pytest
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.testing_utils import is_pipeline_test, require_torch, slow
|
||||
from transformers.testing_utils import is_pipeline_test, is_torch_available, require_torch, slow
|
||||
|
||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.models.mbart import MBart50TokenizerFast, MBartForConditionalGeneration
|
||||
|
||||
|
||||
class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "translation_en_to_de"
|
||||
small_models = ["patrickvonplaten/t5-tiny-random"] # Default model - Models tested without the @slow decorator
|
||||
@@ -48,12 +52,38 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
|
||||
pipeline(task="translation_cn_to_ar")
|
||||
|
||||
# but we do for this one
|
||||
pipeline(task="translation_en_to_de")
|
||||
translator = pipeline(task="translation_en_to_de")
|
||||
self.assertEquals(translator.src_lang, "en")
|
||||
self.assertEquals(translator.tgt_lang, "de")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_multilingual_translation(self):
|
||||
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
||||
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
||||
|
||||
translator = pipeline(task="translation", model=model, tokenizer=tokenizer)
|
||||
# Missing src_lang, tgt_lang
|
||||
with self.assertRaises(ValueError):
|
||||
translator("This is a test")
|
||||
|
||||
outputs = translator("This is a test", src_lang="en_XX", tgt_lang="ar_AR")
|
||||
self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])
|
||||
|
||||
outputs = translator("This is a test", src_lang="en_XX", tgt_lang="hi_IN")
|
||||
self.assertEqual(outputs, [{"translation_text": "यह एक परीक्षण है"}])
|
||||
|
||||
# src_lang, tgt_lang can be defined at pipeline call time
|
||||
translator = pipeline(task="translation", model=model, tokenizer=tokenizer, src_lang="en_XX", tgt_lang="ar_AR")
|
||||
outputs = translator("This is a test")
|
||||
self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])
|
||||
|
||||
@require_torch
|
||||
def test_translation_on_odd_language(self):
|
||||
model = "patrickvonplaten/t5-tiny-random"
|
||||
pipeline(task="translation_cn_to_ar", model=model)
|
||||
translator = pipeline(task="translation_cn_to_ar", model=model)
|
||||
self.assertEquals(translator.src_lang, "cn")
|
||||
self.assertEquals(translator.tgt_lang, "ar")
|
||||
|
||||
@require_torch
|
||||
def test_translation_default_language_selection(self):
|
||||
@@ -61,6 +91,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
|
||||
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
|
||||
nlp = pipeline(task="translation", model=model)
|
||||
self.assertEqual(nlp.task, "translation_en_to_de")
|
||||
self.assertEquals(nlp.src_lang, "en")
|
||||
self.assertEquals(nlp.tgt_lang, "de")
|
||||
|
||||
@require_torch
|
||||
def test_translation_with_no_language_no_model_fails(self):
|
||||
|
||||
@@ -20,7 +20,7 @@ from shutil import copyfile
|
||||
|
||||
from transformers import M2M100Tokenizer, is_torch_available
|
||||
from transformers.file_utils import is_sentencepiece_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
|
||||
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
|
||||
|
||||
|
||||
if is_sentencepiece_available():
|
||||
@@ -191,3 +191,18 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_translation(self):
|
||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en", tgt_lang="ar")
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
{
|
||||
# en_XX, A, test, EOS
|
||||
"input_ids": [[128022, 58, 4183, 2]],
|
||||
"attention_mask": [[1, 1, 1, 1]],
|
||||
# ar_AR
|
||||
"forced_bos_token_id": 128006,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
|
||||
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
@@ -232,3 +232,18 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_translation(self):
|
||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
{
|
||||
# A, test, EOS, en_XX
|
||||
"input_ids": [[62, 3034, 2, 250004]],
|
||||
"attention_mask": [[1, 1, 1, 1]],
|
||||
# ar_AR
|
||||
"forced_bos_token_id": 250001,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
|
||||
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
@@ -194,3 +194,18 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_translation(self):
|
||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
{
|
||||
# en_XX, A, test, EOS
|
||||
"input_ids": [[250004, 62, 3034, 2]],
|
||||
"attention_mask": [[1, 1, 1, 1]],
|
||||
# ar_AR
|
||||
"forced_bos_token_id": 250001,
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user