Enabling multilingual models for translation pipelines. (#10536)
* [WIP] Enabling multilingual models for translation pipelines. * decoder_input_ids -> forced_bos_token_id * Improve docstring. * Rebase * Fixing 2 bugs - Type token_ids coming from `_parse_and_tokenize` - Wrong index from tgt_lang. * Fixing black version. * Adding tests for _build_translation_inputs and add them for all tokenizers. * Mbart actually puts the lang code at the end. * Fixing m2m100. * Adding TF support to `deep_round`. * Update src/transformers/pipelines/text2text_generation.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding one line comment. * Fixing M2M100 `_build_translation_input_ids`, and fix the call site. * Fixing tests + deep_round -> nested_simplify Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -20,7 +20,7 @@ from shutil import copyfile
|
||||
|
||||
from transformers import M2M100Tokenizer, is_torch_available
|
||||
from transformers.file_utils import is_sentencepiece_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
|
||||
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
|
||||
|
||||
|
||||
if is_sentencepiece_available():
|
||||
@@ -191,3 +191,18 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_translation(self):
|
||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en", tgt_lang="ar")
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
{
|
||||
# en_XX, A, test, EOS
|
||||
"input_ids": [[128022, 58, 4183, 2]],
|
||||
"attention_mask": [[1, 1, 1, 1]],
|
||||
# ar_AR
|
||||
"forced_bos_token_id": 128006,
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user