Enabling multilingual models for translation pipelines. (#10536)
* [WIP] Enabling multilingual models for translation pipelines. * decoder_input_ids -> forced_bos_token_id * Improve docstring. * Rebase * Fixing 2 bugs - Type token_ids coming from `_parse_and_tokenize` - Wrong index from tgt_lang. * Fixing black version. * Adding tests for _build_translation_inputs and add them for all tokenizers. * Mbart actually puts the lang code at the end. * Fixing m2m100. * Adding TF support to `deep_round`. * Update src/transformers/pipelines/text2text_generation.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding one line comment. * Fixing M2M100 `_build_translation_input_ids`, and fix the call site. * Fixing tests + deep_round -> nested_simplify Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -361,6 +361,9 @@ if is_torch_available():
|
||||
else:
|
||||
torch_device = None
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch. """
|
||||
@@ -1174,3 +1177,26 @@ def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False
|
||||
raise RuntimeError(f"'{cmd_str}' produced no output.")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def nested_simplify(obj, decimals=3):
|
||||
"""
|
||||
Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
|
||||
within tests.
|
||||
"""
|
||||
from transformers.tokenization_utils import BatchEncoding
|
||||
|
||||
if isinstance(obj, list):
|
||||
return [nested_simplify(item, decimals) for item in obj]
|
||||
elif isinstance(obj, (dict, BatchEncoding)):
|
||||
return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
|
||||
elif isinstance(obj, (str, int)):
|
||||
return obj
|
||||
elif is_torch_available() and isinstance(obj, torch.Tensor):
|
||||
return nested_simplify(obj.tolist())
|
||||
elif is_tf_available() and tf.is_tensor(obj):
|
||||
return nested_simplify(obj.numpy().tolist())
|
||||
elif isinstance(obj, float):
|
||||
return round(obj, decimals)
|
||||
else:
|
||||
raise Exception(f"Not supported: {type(obj)}")
|
||||
|
||||
Reference in New Issue
Block a user