Enabling multilingual models for translation pipelines. (#10536)

* [WIP] Enabling multilingual models for translation pipelines.

* decoder_input_ids -> forced_bos_token_id

* Improve docstring.

* Rebase

* Fixing 2 bugs

- Type token_ids coming from `_parse_and_tokenize`
- Wrong index from tgt_lang.

* Fixing black version.

* Adding tests for _build_translation_inputs and add them for all
tokenizers.

* Mbart actually puts the lang code at the end.

* Fixing m2m100.

* Adding TF support to `deep_round`.

* Update src/transformers/pipelines/text2text_generation.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Adding one line comment.

* Fixing M2M100 `_build_translation_input_ids`, and fix the call site.

* Fixing tests + deep_round -> nested_simplify

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2021-04-16 11:31:35 +02:00
committed by GitHub
parent 5254220e7f
commit 92970c0cb9
12 changed files with 268 additions and 52 deletions

View File

@@ -361,6 +361,9 @@ if is_torch_available():
else:
torch_device = None
if is_tf_available():
import tensorflow as tf
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch. """
@@ -1174,3 +1177,26 @@ def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False
raise RuntimeError(f"'{cmd_str}' produced no output.")
return result
def nested_simplify(obj, decimals=3):
"""
Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
within tests.
"""
from transformers.tokenization_utils import BatchEncoding
if isinstance(obj, list):
return [nested_simplify(item, decimals) for item in obj]
elif isinstance(obj, (dict, BatchEncoding)):
return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
elif isinstance(obj, (str, int)):
return obj
elif is_torch_available() and isinstance(obj, torch.Tensor):
return nested_simplify(obj.tolist())
elif is_tf_available() and tf.is_tensor(obj):
return nested_simplify(obj.numpy().tolist())
elif isinstance(obj, float):
return round(obj, decimals)
else:
raise Exception(f"Not supported: {type(obj)}")