From a3f96f366a49bbe2cbdeaebd2e32ccdc1260a1d6 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 27 Aug 2021 12:26:17 +0200 Subject: [PATCH] Moving `translation` pipeline to new testing scheme. (#13297) * Moving `translation` pipeline to new testing scheme. * Update tokenization mbart tests. --- .../models/mbart/tokenization_mbart.py | 6 +- .../models/mbart/tokenization_mbart_fast.py | 6 +- tests/test_pipelines_translation.py | 94 +++++++++++++++---- tests/test_tokenization_mbart.py | 4 +- 4 files changed, 85 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/mbart/tokenization_mbart.py b/src/transformers/models/mbart/tokenization_mbart.py index 576e62b265..cf3bfab08f 100644 --- a/src/transformers/models/mbart/tokenization_mbart.py +++ b/src/transformers/models/mbart/tokenization_mbart.py @@ -201,12 +201,14 @@ class MBartTokenizer(XLMRobertaTokenizer): # We don't expect to process pairs, but leave the pair logic for API consistency return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens - def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): """Used by translation pipeline, to prepare inputs for the generate function""" if src_lang is None or tgt_lang is None: raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") self.src_lang = src_lang - inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) inputs["forced_bos_token_id"] = tgt_lang_id return inputs diff --git a/src/transformers/models/mbart/tokenization_mbart_fast.py b/src/transformers/models/mbart/tokenization_mbart_fast.py index 94f5eda640..b135ecba4c 100644 
--- a/src/transformers/models/mbart/tokenization_mbart_fast.py +++ b/src/transformers/models/mbart/tokenization_mbart_fast.py @@ -186,12 +186,14 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast): # We don't expect to process pairs, but leave the pair logic for API consistency return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens - def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): + def _build_translation_inputs( + self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs + ): """Used by translation pipeline, to prepare inputs for the generate function""" if src_lang is None or tgt_lang is None: raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model") self.src_lang = src_lang - inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) + inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs) tgt_lang_id = self.convert_tokens_to_ids(tgt_lang) inputs["forced_bos_token_id"] = tgt_lang_id return inputs diff --git a/tests/test_pipelines_translation.py b/tests/test_pipelines_translation.py index 4456410d6f..821f0db44b 100644 --- a/tests/test_pipelines_translation.py +++ b/tests/test_pipelines_translation.py @@ -16,31 +16,85 @@ import unittest import pytest -from transformers import pipeline -from transformers.testing_utils import is_pipeline_test, is_torch_available, require_torch, slow +from transformers import ( + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + MBart50TokenizerFast, + MBartForConditionalGeneration, + TranslationPipeline, + pipeline, +) +from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, slow -from .test_pipelines_common import MonoInputPipelineCommonMixin +from .test_pipelines_common import ANY, PipelineTestCaseMeta -if is_torch_available(): - from 
transformers.models.mbart import MBartForConditionalGeneration - from transformers.models.mbart50 import MBart50TokenizerFast +@is_pipeline_test +class TranslationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): + model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING + tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING + def run_pipeline_test(self, model, tokenizer, feature_extractor): + translator = TranslationPipeline(model=model, tokenizer=tokenizer) + try: + outputs = translator("Some string") + except ValueError: + # Triggered by m2m languages + src_lang, tgt_lang = list(translator.tokenizer.lang_code_to_id.keys())[:2] + outputs = translator("Some string", src_lang=src_lang, tgt_lang=tgt_lang) + self.assertEqual(outputs, [{"translation_text": ANY(str)}]) -class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): - pipeline_task = "translation_en_to_de" - small_models = ["patrickvonplaten/t5-tiny-random"] # Default model - Models tested without the @slow decorator - large_models = [None] # Models tested with the @slow decorator - invalid_inputs = [4, ""] - mandatory_keys = ["translation_text"] + @require_torch + def test_small_model_pt(self): + translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random", framework="pt") + outputs = translator("This is a test string", max_length=20) + self.assertEqual( + outputs, + [ + { + "translation_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide" + } + ], + ) + @require_tf + def test_small_model_tf(self): + translator = pipeline("translation_en_to_ro", model="patrickvonplaten/t5-tiny-random", framework="tf") + outputs = translator("This is a test string", max_length=20) + self.assertEqual( + outputs, + [ + { + "translation_text": "Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide Beide" + } + ], + ) -class
TranslationEnToRoPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): - pipeline_task = "translation_en_to_ro" - small_models = ["patrickvonplaten/t5-tiny-random"] # Default model - Models tested without the @slow decorator - large_models = [None] # Models tested with the @slow decorator - invalid_inputs = [4, ""] - mandatory_keys = ["translation_text"] + @require_torch + def test_en_to_de_pt(self): + translator = pipeline("translation_en_to_de", model="patrickvonplaten/t5-tiny-random", framework="pt") + outputs = translator("This is a test string", max_length=20) + self.assertEqual( + outputs, + [ + { + "translation_text": "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine urine urine urine urine urine urine urine" + } + ], + ) + + @require_tf + def test_en_to_de_tf(self): + translator = pipeline("translation_en_to_de", model="patrickvonplaten/t5-tiny-random", framework="tf") + outputs = translator("This is a test string", max_length=20) + self.assertEqual( + outputs, + [ + { + "translation_text": "monoton monoton monoton monoton monoton monoton monoton monoton monoton monoton urine urine urine urine urine urine urine urine urine" + } + ], + ) @is_pipeline_test @@ -92,8 +146,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase): with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"): translator = pipeline(task="translation", model=model) self.assertEqual(translator.task, "translation_en_to_de") - self.assertEquals(translator.src_lang, "en") - self.assertEquals(translator.tgt_lang, "de") + self.assertEqual(translator.src_lang, "en") + self.assertEqual(translator.tgt_lang, "de") @require_torch def test_translation_with_no_language_no_model_fails(self): diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py index 640aec60fd..02a87d285c 100644 --- a/tests/test_tokenization_mbart.py +++ b/tests/test_tokenization_mbart.py @@ -235,7 +235,9 @@ class 
MBartEnroIntegrationTest(unittest.TestCase): @require_torch def test_tokenizer_translation(self): - inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR") + inputs = self.tokenizer._build_translation_inputs( + "A test", return_tensors="pt", src_lang="en_XX", tgt_lang="ar_AR" + ) self.assertEqual( nested_simplify(inputs),