Enabling multilingual models for translation pipelines. (#10536)
* [WIP] Enabling multilingual models for translation pipelines. * decoder_input_ids -> forced_bos_token_id * Improve docstring. * Rebase * Fixing 2 bugs - Type token_ids coming from `_parse_and_tokenize` - Wrong index from tgt_lang. * Fixing black version. * Adding tests for _build_translation_inputs and add them for all tokenizers. * Mbart actually puts the lang code at the end. * Fixing m2m100. * Adding TF support to `deep_round`. * Update src/transformers/pipelines/text2text_generation.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding one line comment. * Fixing M2M100 `_build_translation_input_ids`, and fix the call site. * Fixing tests + deep_round -> nested_simplify Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -288,6 +288,16 @@ class M2M100Tokenizer(PreTrainedTokenizer):
|
|||||||
self.set_src_lang_special_tokens(self.src_lang)
|
self.set_src_lang_special_tokens(self.src_lang)
|
||||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||||
|
|
||||||
|
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||||
|
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||||
|
if src_lang is None or tgt_lang is None:
|
||||||
|
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||||
|
self.src_lang = src_lang
|
||||||
|
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||||
|
tgt_lang_id = self.get_lang_id(tgt_lang)
|
||||||
|
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||||
|
return inputs
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def as_target_tokenizer(self):
|
def as_target_tokenizer(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -186,6 +186,16 @@ class MBartTokenizer(XLMRobertaTokenizer):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||||
|
|
||||||
|
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||||
|
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||||
|
if src_lang is None or tgt_lang is None:
|
||||||
|
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||||
|
self.src_lang = src_lang
|
||||||
|
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||||
|
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||||
|
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||||
|
return inputs
|
||||||
|
|
||||||
def prepare_seq2seq_batch(
|
def prepare_seq2seq_batch(
|
||||||
self,
|
self,
|
||||||
src_texts: List[str],
|
src_texts: List[str],
|
||||||
|
|||||||
@@ -278,6 +278,16 @@ class MBart50Tokenizer(PreTrainedTokenizer):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||||
|
|
||||||
|
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||||
|
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||||
|
if src_lang is None or tgt_lang is None:
|
||||||
|
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||||
|
self.src_lang = src_lang
|
||||||
|
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||||
|
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||||
|
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||||
|
return inputs
|
||||||
|
|
||||||
def prepare_seq2seq_batch(
|
def prepare_seq2seq_batch(
|
||||||
self,
|
self,
|
||||||
src_texts: List[str],
|
src_texts: List[str],
|
||||||
|
|||||||
@@ -241,6 +241,16 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
|
|||||||
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
|
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||||
|
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||||
|
if src_lang is None or tgt_lang is None:
|
||||||
|
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||||
|
self.src_lang = src_lang
|
||||||
|
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||||
|
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||||
|
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||||
|
return inputs
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
if not os.path.isdir(save_directory):
|
if not os.path.isdir(save_directory):
|
||||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||||
|
|||||||
@@ -160,6 +160,16 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
|
|||||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||||
|
|
||||||
|
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
|
||||||
|
"""Used by translation pipeline, to prepare inputs for the generate function"""
|
||||||
|
if src_lang is None or tgt_lang is None:
|
||||||
|
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
|
||||||
|
self.src_lang = src_lang
|
||||||
|
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
|
||||||
|
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
|
||||||
|
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||||
|
return inputs
|
||||||
|
|
||||||
def prepare_seq2seq_batch(
|
def prepare_seq2seq_batch(
|
||||||
self,
|
self,
|
||||||
src_texts: List[str],
|
src_texts: List[str],
|
||||||
|
|||||||
@@ -616,7 +616,10 @@ class Pipeline(_ScikitCompat):
|
|||||||
Return:
|
Return:
|
||||||
:obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
|
:obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
|
||||||
"""
|
"""
|
||||||
return {name: tensor.to(self.device) for name, tensor in inputs.items()}
|
return {
|
||||||
|
name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
|
||||||
|
for name, tensor in inputs.items()
|
||||||
|
}
|
||||||
|
|
||||||
def check_model_type(self, supported_models: Union[List[str], dict]):
|
def check_model_type(self, supported_models: Union[List[str], dict]):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
|
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
|
||||||
from ..tokenization_utils import TruncationStrategy
|
from ..tokenization_utils import TruncationStrategy
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
@@ -50,6 +52,28 @@ class Text2TextGenerationPipeline(Pipeline):
|
|||||||
"""
|
"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _parse_and_tokenize(self, *args, truncation):
|
||||||
|
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
||||||
|
if isinstance(args[0], list):
|
||||||
|
assert (
|
||||||
|
self.tokenizer.pad_token_id is not None
|
||||||
|
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
|
||||||
|
args = ([prefix + arg for arg in args[0]],)
|
||||||
|
padding = True
|
||||||
|
|
||||||
|
elif isinstance(args[0], str):
|
||||||
|
args = (prefix + args[0],)
|
||||||
|
padding = False
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`"
|
||||||
|
)
|
||||||
|
inputs = super()._parse_and_tokenize(*args, padding=padding, truncation=truncation)
|
||||||
|
# This is produced by tokenizers but is an invalid generate kwargs
|
||||||
|
if "token_type_ids" in inputs:
|
||||||
|
del inputs["token_type_ids"]
|
||||||
|
return inputs
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
@@ -88,53 +112,41 @@ class Text2TextGenerationPipeline(Pipeline):
|
|||||||
"""
|
"""
|
||||||
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||||
|
|
||||||
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
|
|
||||||
if isinstance(args[0], list):
|
|
||||||
assert (
|
|
||||||
self.tokenizer.pad_token_id is not None
|
|
||||||
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
|
|
||||||
args = ([prefix + arg for arg in args[0]],)
|
|
||||||
padding = True
|
|
||||||
|
|
||||||
elif isinstance(args[0], str):
|
|
||||||
args = (prefix + args[0],)
|
|
||||||
padding = False
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`"
|
|
||||||
)
|
|
||||||
|
|
||||||
with self.device_placement():
|
with self.device_placement():
|
||||||
inputs = self._parse_and_tokenize(*args, padding=padding, truncation=truncation)
|
inputs = self._parse_and_tokenize(*args, truncation=truncation)
|
||||||
|
return self._generate(inputs, return_tensors, return_text, clean_up_tokenization_spaces, generate_kwargs)
|
||||||
|
|
||||||
if self.framework == "pt":
|
def _generate(
|
||||||
inputs = self.ensure_tensor_on_device(**inputs)
|
self, inputs, return_tensors: bool, return_text: bool, clean_up_tokenization_spaces: bool, generate_kwargs
|
||||||
input_length = inputs["input_ids"].shape[-1]
|
):
|
||||||
elif self.framework == "tf":
|
if self.framework == "pt":
|
||||||
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
|
inputs = self.ensure_tensor_on_device(**inputs)
|
||||||
|
input_length = inputs["input_ids"].shape[-1]
|
||||||
|
elif self.framework == "tf":
|
||||||
|
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
|
||||||
|
|
||||||
min_length = generate_kwargs.get("min_length", self.model.config.min_length)
|
min_length = generate_kwargs.get("min_length", self.model.config.min_length)
|
||||||
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
|
||||||
self.check_inputs(input_length, min_length, max_length)
|
self.check_inputs(input_length, min_length, max_length)
|
||||||
|
|
||||||
generations = self.model.generate(
|
generate_kwargs.update(inputs)
|
||||||
inputs["input_ids"],
|
|
||||||
attention_mask=inputs["attention_mask"],
|
generations = self.model.generate(
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
results = []
|
results = []
|
||||||
for generation in generations:
|
for generation in generations:
|
||||||
record = {}
|
record = {}
|
||||||
if return_tensors:
|
if return_tensors:
|
||||||
record[f"{self.return_name}_token_ids"] = generation
|
record[f"{self.return_name}_token_ids"] = generation
|
||||||
if return_text:
|
if return_text:
|
||||||
record[f"{self.return_name}_text"] = self.tokenizer.decode(
|
record[f"{self.return_name}_text"] = self.tokenizer.decode(
|
||||||
generation,
|
generation,
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
)
|
)
|
||||||
results.append(record)
|
results.append(record)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||||
@@ -226,6 +238,23 @@ class TranslationPipeline(Text2TextGenerationPipeline):
|
|||||||
|
|
||||||
# Used in the return key of the pipeline.
|
# Used in the return key of the pipeline.
|
||||||
return_name = "translation"
|
return_name = "translation"
|
||||||
|
src_lang: Optional[str] = None
|
||||||
|
tgt_lang: Optional[str] = None
|
||||||
|
|
||||||
|
def __init__(self, *args, src_lang=None, tgt_lang=None, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
if src_lang is not None:
|
||||||
|
self.src_lang = src_lang
|
||||||
|
if tgt_lang is not None:
|
||||||
|
self.tgt_lang = tgt_lang
|
||||||
|
if src_lang is None and tgt_lang is None:
|
||||||
|
# Backward compatibility, direct arguments use is preferred.
|
||||||
|
task = kwargs.get("task", "")
|
||||||
|
items = task.split("_")
|
||||||
|
if task and len(items) == 4:
|
||||||
|
# translation, XX, to YY
|
||||||
|
self.src_lang = items[1]
|
||||||
|
self.tgt_lang = items[3]
|
||||||
|
|
||||||
def check_inputs(self, input_length: int, min_length: int, max_length: int):
|
def check_inputs(self, input_length: int, min_length: int, max_length: int):
|
||||||
if input_length > 0.9 * max_length:
|
if input_length > 0.9 * max_length:
|
||||||
@@ -233,8 +262,27 @@ class TranslationPipeline(Text2TextGenerationPipeline):
|
|||||||
f"Your input_length: {input_length} is bigger than 0.9 * max_length: {max_length}. You might consider "
|
f"Your input_length: {input_length} is bigger than 0.9 * max_length: {max_length}. You might consider "
|
||||||
"increasing your max_length manually, e.g. translator('...', max_length=400)"
|
"increasing your max_length manually, e.g. translator('...', max_length=400)"
|
||||||
)
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def _parse_and_tokenize(self, *args, src_lang, tgt_lang, truncation):
|
||||||
|
if getattr(self.tokenizer, "_build_translation_inputs", None):
|
||||||
|
return self.tokenizer._build_translation_inputs(
|
||||||
|
*args, src_lang=src_lang, tgt_lang=tgt_lang, truncation=truncation
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return super()._parse_and_tokenize(*args, truncation=truncation)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
return_tensors=False,
|
||||||
|
return_text=True,
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
|
truncation=TruncationStrategy.DO_NOT_TRUNCATE,
|
||||||
|
src_lang=None,
|
||||||
|
tgt_lang=None,
|
||||||
|
**generate_kwargs
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Translate the text(s) given as inputs.
|
Translate the text(s) given as inputs.
|
||||||
|
|
||||||
@@ -247,6 +295,12 @@ class TranslationPipeline(Text2TextGenerationPipeline):
|
|||||||
Whether or not to include the decoded texts in the outputs.
|
Whether or not to include the decoded texts in the outputs.
|
||||||
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to clean up the potential extra spaces in the text output.
|
Whether or not to clean up the potential extra spaces in the text output.
|
||||||
|
src_lang (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||||
|
The language of the input. Might be required for multilingual models. Will not have any effect for
|
||||||
|
single pair translation models
|
||||||
|
tgt_lang (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||||
|
The language of the desired output. Might be required for multilingual models. Will not have any effect
|
||||||
|
for single pair translation models
|
||||||
generate_kwargs:
|
generate_kwargs:
|
||||||
Additional keyword arguments to pass along to the generate method of the model (see the generate method
|
Additional keyword arguments to pass along to the generate method of the model (see the generate method
|
||||||
corresponding to your framework `here <./model.html#generative-models>`__).
|
corresponding to your framework `here <./model.html#generative-models>`__).
|
||||||
@@ -258,4 +312,10 @@ class TranslationPipeline(Text2TextGenerationPipeline):
|
|||||||
- **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
|
- **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
|
||||||
-- The token ids of the translation.
|
-- The token ids of the translation.
|
||||||
"""
|
"""
|
||||||
return super().__call__(*args, **kwargs)
|
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||||
|
src_lang = src_lang if src_lang is not None else self.src_lang
|
||||||
|
tgt_lang = tgt_lang if tgt_lang is not None else self.tgt_lang
|
||||||
|
|
||||||
|
with self.device_placement():
|
||||||
|
inputs = self._parse_and_tokenize(*args, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang)
|
||||||
|
return self._generate(inputs, return_tensors, return_text, clean_up_tokenization_spaces, generate_kwargs)
|
||||||
|
|||||||
@@ -361,6 +361,9 @@ if is_torch_available():
|
|||||||
else:
|
else:
|
||||||
torch_device = None
|
torch_device = None
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
def require_torch_gpu(test_case):
|
def require_torch_gpu(test_case):
|
||||||
"""Decorator marking a test that requires CUDA and PyTorch. """
|
"""Decorator marking a test that requires CUDA and PyTorch. """
|
||||||
@@ -1174,3 +1177,26 @@ def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False
|
|||||||
raise RuntimeError(f"'{cmd_str}' produced no output.")
|
raise RuntimeError(f"'{cmd_str}' produced no output.")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def nested_simplify(obj, decimals=3):
|
||||||
|
"""
|
||||||
|
Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
|
||||||
|
within tests.
|
||||||
|
"""
|
||||||
|
from transformers.tokenization_utils import BatchEncoding
|
||||||
|
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [nested_simplify(item, decimals) for item in obj]
|
||||||
|
elif isinstance(obj, (dict, BatchEncoding)):
|
||||||
|
return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
|
||||||
|
elif isinstance(obj, (str, int)):
|
||||||
|
return obj
|
||||||
|
elif is_torch_available() and isinstance(obj, torch.Tensor):
|
||||||
|
return nested_simplify(obj.tolist())
|
||||||
|
elif is_tf_available() and tf.is_tensor(obj):
|
||||||
|
return nested_simplify(obj.numpy().tolist())
|
||||||
|
elif isinstance(obj, float):
|
||||||
|
return round(obj, decimals)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Not supported: {type(obj)}")
|
||||||
|
|||||||
@@ -17,11 +17,15 @@ import unittest
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
from transformers.testing_utils import is_pipeline_test, require_torch, slow
|
from transformers.testing_utils import is_pipeline_test, is_torch_available, require_torch, slow
|
||||||
|
|
||||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers.models.mbart import MBart50TokenizerFast, MBartForConditionalGeneration
|
||||||
|
|
||||||
|
|
||||||
class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||||
pipeline_task = "translation_en_to_de"
|
pipeline_task = "translation_en_to_de"
|
||||||
small_models = ["patrickvonplaten/t5-tiny-random"] # Default model - Models tested without the @slow decorator
|
small_models = ["patrickvonplaten/t5-tiny-random"] # Default model - Models tested without the @slow decorator
|
||||||
@@ -48,12 +52,38 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
|
|||||||
pipeline(task="translation_cn_to_ar")
|
pipeline(task="translation_cn_to_ar")
|
||||||
|
|
||||||
# but we do for this one
|
# but we do for this one
|
||||||
pipeline(task="translation_en_to_de")
|
translator = pipeline(task="translation_en_to_de")
|
||||||
|
self.assertEquals(translator.src_lang, "en")
|
||||||
|
self.assertEquals(translator.tgt_lang, "de")
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@slow
|
||||||
|
def test_multilingual_translation(self):
|
||||||
|
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
||||||
|
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
||||||
|
|
||||||
|
translator = pipeline(task="translation", model=model, tokenizer=tokenizer)
|
||||||
|
# Missing src_lang, tgt_lang
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
translator("This is a test")
|
||||||
|
|
||||||
|
outputs = translator("This is a test", src_lang="en_XX", tgt_lang="ar_AR")
|
||||||
|
self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])
|
||||||
|
|
||||||
|
outputs = translator("This is a test", src_lang="en_XX", tgt_lang="hi_IN")
|
||||||
|
self.assertEqual(outputs, [{"translation_text": "यह एक परीक्षण है"}])
|
||||||
|
|
||||||
|
# src_lang, tgt_lang can be defined at pipeline call time
|
||||||
|
translator = pipeline(task="translation", model=model, tokenizer=tokenizer, src_lang="en_XX", tgt_lang="ar_AR")
|
||||||
|
outputs = translator("This is a test")
|
||||||
|
self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_translation_on_odd_language(self):
|
def test_translation_on_odd_language(self):
|
||||||
model = "patrickvonplaten/t5-tiny-random"
|
model = "patrickvonplaten/t5-tiny-random"
|
||||||
pipeline(task="translation_cn_to_ar", model=model)
|
translator = pipeline(task="translation_cn_to_ar", model=model)
|
||||||
|
self.assertEquals(translator.src_lang, "cn")
|
||||||
|
self.assertEquals(translator.tgt_lang, "ar")
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_translation_default_language_selection(self):
|
def test_translation_default_language_selection(self):
|
||||||
@@ -61,6 +91,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
|
|||||||
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
|
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
|
||||||
nlp = pipeline(task="translation", model=model)
|
nlp = pipeline(task="translation", model=model)
|
||||||
self.assertEqual(nlp.task, "translation_en_to_de")
|
self.assertEqual(nlp.task, "translation_en_to_de")
|
||||||
|
self.assertEquals(nlp.src_lang, "en")
|
||||||
|
self.assertEquals(nlp.tgt_lang, "de")
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_translation_with_no_language_no_model_fails(self):
|
def test_translation_with_no_language_no_model_fails(self):
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from shutil import copyfile
|
|||||||
|
|
||||||
from transformers import M2M100Tokenizer, is_torch_available
|
from transformers import M2M100Tokenizer, is_torch_available
|
||||||
from transformers.file_utils import is_sentencepiece_available
|
from transformers.file_utils import is_sentencepiece_available
|
||||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
|
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
|
||||||
|
|
||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
@@ -191,3 +191,18 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
|
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
|
||||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
|
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_tokenizer_translation(self):
|
||||||
|
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en", tgt_lang="ar")
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(inputs),
|
||||||
|
{
|
||||||
|
# en_XX, A, test, EOS
|
||||||
|
"input_ids": [[128022, 58, 4183, 2]],
|
||||||
|
"attention_mask": [[1, 1, 1, 1]],
|
||||||
|
# ar_AR
|
||||||
|
"forced_bos_token_id": 128006,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
|
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
|
||||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
|
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
|
|
||||||
@@ -232,3 +232,18 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_tokenizer_translation(self):
|
||||||
|
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(inputs),
|
||||||
|
{
|
||||||
|
# A, test, EOS, en_XX
|
||||||
|
"input_ids": [[62, 3034, 2, 250004]],
|
||||||
|
"attention_mask": [[1, 1, 1, 1]],
|
||||||
|
# ar_AR
|
||||||
|
"forced_bos_token_id": 250001,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available
|
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available
|
||||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
|
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
|
||||||
|
|
||||||
from .test_tokenization_common import TokenizerTesterMixin
|
from .test_tokenization_common import TokenizerTesterMixin
|
||||||
|
|
||||||
@@ -194,3 +194,18 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_tokenizer_translation(self):
|
||||||
|
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(inputs),
|
||||||
|
{
|
||||||
|
# en_XX, A, test, EOS
|
||||||
|
"input_ids": [[250004, 62, 3034, 2]],
|
||||||
|
"attention_mask": [[1, 1, 1, 1]],
|
||||||
|
# ar_AR
|
||||||
|
"forced_bos_token_id": 250001,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user