Fixing the "translation", "translation_XX_to_YY" pipelines. (#7975)

* Actually make the "translation", "translation_XX_to_YY" task behave correctly.

Background:
- Currently "translation_cn_to_ar" does not work. (only 3 pairs are
supported)
- Some models, contain in their config the correct values for the (src,
tgt) pair they can translate. It's usually just one pair, and we can
infer it automatically from the `model.config.task_specific_params`. If
it's not defined we can still probably load the TranslationPipeline
nevertheless.

Proposed fix:
- A simplified version of what could become more general which is
a `parametrized` task. "translation" + (src, tgt) in this instance
it what we need in the general case. The way we go about it for now
is simply parsing "translation_XX_to_YY". If cases of parametrized task arise
we should preferably go in something closer to what `datasets` propose
which is having a secondary argument `task_options`? that will be close
to what that task requires.
- Should be backward compatible in all cases for instance
`pipeline(task="translation_en_to_de") should work out of the box.
- Should provide a warning when a specific translation pair has been
selected on behalf of the user using
`model.config.task_specific_params`.

* Update src/transformers/pipelines.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Nicolas Patry
2020-10-22 17:16:21 +02:00
committed by GitHub
parent 901e9b8eda
commit 18ce6b8ff3
2 changed files with 111 additions and 21 deletions

View File

@@ -1,6 +1,8 @@
import unittest
from typing import Iterable, List, Optional
import pytest
from transformers import pipeline
from transformers.pipelines import SUPPORTED_TASKS, Conversation, DefaultArgumentHandler, Pipeline
from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow, torch_device
@@ -392,6 +394,33 @@ class MonoColumnInputTestCase(unittest.TestCase):
invalid_inputs,
)
@require_torch
@slow
def test_default_translations(self):
# We don't provide a default for this pair
with self.assertRaises(ValueError):
pipeline(task="translation_cn_to_ar")
# but we do for this one
pipeline(task="translation_en_to_de")
@require_torch
def test_translation_on_odd_language(self):
model = TRANSLATION_FINETUNED_MODELS[0][0]
pipeline(task="translation_cn_to_ar", model=model)
@require_torch
def test_translation_default_language_selection(self):
model = TRANSLATION_FINETUNED_MODELS[0][0]
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
nlp = pipeline(task="translation", model=model)
self.assertEqual(nlp.task, "translation_en_to_de")
@require_torch
def test_translation_with_no_language_no_model_fails(self):
with self.assertRaises(ValueError):
pipeline(task="translation")
@require_tf
@slow
def test_tf_translation(self):