Enabling multilingual models for translation pipelines. (#10536)
* [WIP] Enabling multilingual models for translation pipelines. * decoder_input_ids -> forced_bos_token_id * Improve docstring. * Rebase * Fixing 2 bugs - Type token_ids coming from `_parse_and_tokenize` - Wrong index from tgt_lang. * Fixing black version. * Adding tests for _build_translation_inputs and add them for all tokenizers. * Mbart actually puts the lang code at the end. * Fixing m2m100. * Adding TF support to `deep_round`. * Update src/transformers/pipelines/text2text_generation.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding one line comment. * Fixing M2M100 `_build_translation_input_ids`, and fix the call site. * Fixing tests + deep_round -> nested_simplify Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -17,7 +17,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
|
||||
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
@@ -194,3 +194,18 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_translation(self):
|
||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
{
|
||||
# en_XX, A, test, EOS
|
||||
"input_ids": [[250004, 62, 3034, 2]],
|
||||
"attention_mask": [[1, 1, 1, 1]],
|
||||
# ar_AR
|
||||
"forced_bos_token_id": 250001,
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user