Enabling multilingual models for translation pipelines. (#10536)

* [WIP] Enabling multilingual models for translation pipelines.

* decoder_input_ids -> forced_bos_token_id

* Improve docstring.

* Rebase

* Fixing 2 bugs

- Type token_ids coming from `_parse_and_tokenize`
- Wrong index from tgt_lang.

* Fixing black version.

* Adding tests for _build_translation_inputs and add them for all
tokenizers.

* Mbart actually puts the lang code at the end.

* Fixing m2m100.

* Adding TF support to `deep_round`.

* Update src/transformers/pipelines/text2text_generation.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Adding one line comment.

* Fixing M2M100 `_build_translation_input_ids`, and fix the call site.

* Fixing tests + deep_round -> nested_simplify

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2021-04-16 11:31:35 +02:00
committed by GitHub
parent 5254220e7f
commit 92970c0cb9
12 changed files with 268 additions and 52 deletions

View File

@@ -17,11 +17,15 @@ import unittest
import pytest
from transformers import pipeline
from transformers.testing_utils import is_pipeline_test, require_torch, slow
from transformers.testing_utils import is_pipeline_test, is_torch_available, require_torch, slow
from .test_pipelines_common import MonoInputPipelineCommonMixin
if is_torch_available():
from transformers.models.mbart import MBart50TokenizerFast, MBartForConditionalGeneration
class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "translation_en_to_de"
small_models = ["patrickvonplaten/t5-tiny-random"] # Default model - Models tested without the @slow decorator
@@ -48,12 +52,38 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
pipeline(task="translation_cn_to_ar")
# but we do for this one
pipeline(task="translation_en_to_de")
translator = pipeline(task="translation_en_to_de")
self.assertEquals(translator.src_lang, "en")
self.assertEquals(translator.tgt_lang, "de")
@require_torch
@slow
def test_multilingual_translation(self):
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translator = pipeline(task="translation", model=model, tokenizer=tokenizer)
# Missing src_lang, tgt_lang
with self.assertRaises(ValueError):
translator("This is a test")
outputs = translator("This is a test", src_lang="en_XX", tgt_lang="ar_AR")
self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])
outputs = translator("This is a test", src_lang="en_XX", tgt_lang="hi_IN")
self.assertEqual(outputs, [{"translation_text": "यह एक परीक्षण है"}])
# src_lang, tgt_lang can be defined at pipeline call time
translator = pipeline(task="translation", model=model, tokenizer=tokenizer, src_lang="en_XX", tgt_lang="ar_AR")
outputs = translator("This is a test")
self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])
@require_torch
def test_translation_on_odd_language(self):
model = "patrickvonplaten/t5-tiny-random"
pipeline(task="translation_cn_to_ar", model=model)
translator = pipeline(task="translation_cn_to_ar", model=model)
self.assertEquals(translator.src_lang, "cn")
self.assertEquals(translator.tgt_lang, "ar")
@require_torch
def test_translation_default_language_selection(self):
@@ -61,6 +91,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
nlp = pipeline(task="translation", model=model)
self.assertEqual(nlp.task, "translation_en_to_de")
self.assertEquals(nlp.src_lang, "en")
self.assertEquals(nlp.tgt_lang, "de")
@require_torch
def test_translation_with_no_language_no_model_fails(self):

View File

@@ -20,7 +20,7 @@ from shutil import copyfile
from transformers import M2M100Tokenizer, is_torch_available
from transformers.file_utils import is_sentencepiece_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
if is_sentencepiece_available():
@@ -191,3 +191,18 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
@require_torch
def test_tokenizer_translation(self):
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en", tgt_lang="ar")
self.assertEqual(
nested_simplify(inputs),
{
# en_XX, A, test, EOS
"input_ids": [[128022, 58, 4183, 2]],
"attention_mask": [[1, 1, 1, 1]],
# ar_AR
"forced_bos_token_id": 128006,
},
)

View File

@@ -17,7 +17,7 @@ import tempfile
import unittest
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
from .test_tokenization_common import TokenizerTesterMixin
@@ -232,3 +232,18 @@ class MBartEnroIntegrationTest(unittest.TestCase):
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
@require_torch
def test_tokenizer_translation(self):
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
self.assertEqual(
nested_simplify(inputs),
{
# A, test, EOS, en_XX
"input_ids": [[62, 3034, 2, 250004]],
"attention_mask": [[1, 1, 1, 1]],
# ar_AR
"forced_bos_token_id": 250001,
},
)

View File

@@ -17,7 +17,7 @@ import tempfile
import unittest
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch
from .test_tokenization_common import TokenizerTesterMixin
@@ -194,3 +194,18 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
self.assertEqual(batch.input_ids.shape[1], 3)
self.assertEqual(batch.decoder_input_ids.shape[1], 10)
@require_torch
def test_tokenizer_translation(self):
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
self.assertEqual(
nested_simplify(inputs),
{
# en_XX, A, test, EOS
"input_ids": [[250004, 62, 3034, 2]],
"attention_mask": [[1, 1, 1, 1]],
# ar_AR
"forced_bos_token_id": 250001,
},
)