Enabling multilingual models for translation pipelines. (#10536)
* [WIP] Enabling multilingual models for translation pipelines. * decoder_input_ids -> forced_bos_token_id * Improve docstring. * Rebase * Fixing 2 bugs - Type token_ids coming from `_parse_and_tokenize` - Wrong index from tgt_lang. * Fixing black version. * Adding tests for _build_translation_inputs and add them for all tokenizers. * Mbart actually puts the lang code at the end. * Fixing m2m100. * Adding TF support to `deep_round`. * Update src/transformers/pipelines/text2text_generation.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding one line comment. * Fixing M2M100 `_build_translation_input_ids`, and fix the call site. * Fixing tests + deep_round -> nested_simplify Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -17,11 +17,15 @@ import unittest
|
||||
import pytest
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.testing_utils import is_pipeline_test, require_torch, slow
|
||||
from transformers.testing_utils import is_pipeline_test, is_torch_available, require_torch, slow
|
||||
|
||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.models.mbart import MBart50TokenizerFast, MBartForConditionalGeneration
|
||||
|
||||
|
||||
class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "translation_en_to_de"
|
||||
small_models = ["patrickvonplaten/t5-tiny-random"] # Default model - Models tested without the @slow decorator
|
||||
@@ -48,12 +52,38 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
|
||||
pipeline(task="translation_cn_to_ar")
|
||||
|
||||
# but we do for this one
|
||||
pipeline(task="translation_en_to_de")
|
||||
translator = pipeline(task="translation_en_to_de")
|
||||
self.assertEquals(translator.src_lang, "en")
|
||||
self.assertEquals(translator.tgt_lang, "de")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_multilingual_translation(self):
|
||||
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
||||
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
||||
|
||||
translator = pipeline(task="translation", model=model, tokenizer=tokenizer)
|
||||
# Missing src_lang, tgt_lang
|
||||
with self.assertRaises(ValueError):
|
||||
translator("This is a test")
|
||||
|
||||
outputs = translator("This is a test", src_lang="en_XX", tgt_lang="ar_AR")
|
||||
self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])
|
||||
|
||||
outputs = translator("This is a test", src_lang="en_XX", tgt_lang="hi_IN")
|
||||
self.assertEqual(outputs, [{"translation_text": "यह एक परीक्षण है"}])
|
||||
|
||||
# src_lang, tgt_lang can be defined at pipeline call time
|
||||
translator = pipeline(task="translation", model=model, tokenizer=tokenizer, src_lang="en_XX", tgt_lang="ar_AR")
|
||||
outputs = translator("This is a test")
|
||||
self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])
|
||||
|
||||
@require_torch
|
||||
def test_translation_on_odd_language(self):
|
||||
model = "patrickvonplaten/t5-tiny-random"
|
||||
pipeline(task="translation_cn_to_ar", model=model)
|
||||
translator = pipeline(task="translation_cn_to_ar", model=model)
|
||||
self.assertEquals(translator.src_lang, "cn")
|
||||
self.assertEquals(translator.tgt_lang, "ar")
|
||||
|
||||
@require_torch
|
||||
def test_translation_default_language_selection(self):
|
||||
@@ -61,6 +91,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
|
||||
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
|
||||
nlp = pipeline(task="translation", model=model)
|
||||
self.assertEqual(nlp.task, "translation_en_to_de")
|
||||
self.assertEquals(nlp.src_lang, "en")
|
||||
self.assertEquals(nlp.tgt_lang, "de")
|
||||
|
||||
@require_torch
|
||||
def test_translation_with_no_language_no_model_fails(self):
|
||||
|
||||
@@ -20,7 +20,7 @@ from shutil import copyfile
|
||||
|
||||
from transformers import M2M100Tokenizer, is_torch_available
|
||||
from transformers.file_utils import is_sentencepiece_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
|
||||
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
|
||||
|
||||
|
||||
if is_sentencepiece_available():
|
||||
@@ -191,3 +191,18 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_translation(self):
|
||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en", tgt_lang="ar")
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
{
|
||||
# en_XX, A, test, EOS
|
||||
"input_ids": [[128022, 58, 4183, 2]],
|
||||
"attention_mask": [[1, 1, 1, 1]],
|
||||
# ar_AR
|
||||
"forced_bos_token_id": 128006,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
|
||||
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
@@ -232,3 +232,18 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_translation(self):
|
||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
{
|
||||
# A, test, EOS, en_XX
|
||||
"input_ids": [[62, 3034, 2, 250004]],
|
||||
"attention_mask": [[1, 1, 1, 1]],
|
||||
# ar_AR
|
||||
"forced_bos_token_id": 250001,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
|
||||
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
@@ -194,3 +194,18 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(batch.input_ids.shape[1], 3)
|
||||
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_translation(self):
|
||||
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
|
||||
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
{
|
||||
# en_XX, A, test, EOS
|
||||
"input_ids": [[250004, 62, 3034, 2]],
|
||||
"attention_mask": [[1, 1, 1, 1]],
|
||||
# ar_AR
|
||||
"forced_bos_token_id": 250001,
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user