Fixing the "translation", "translation_XX_to_YY" pipelines. (#7975)
* Actually make the "translation" and "translation_XX_to_YY" tasks behave correctly.

Background:
- Currently "translation_cn_to_ar" does not work (only 3 pairs are supported).
- Some models contain, in their config, the correct values for the (src, tgt) pair they can translate. It's usually just one pair, and we can infer it automatically from `model.config.task_specific_params`. If it's not defined, we can probably still load the TranslationPipeline.

Proposed fix:
- A simplified version of what could become a more general `parametrized` task. "translation" + (src, tgt) is what we need in the general case. For now we simply parse "translation_XX_to_YY". If more cases of parametrized tasks arise, we should preferably move to something closer to what `datasets` proposes: a secondary argument (`task_options`?) that carries what the task requires.
- Should be backward compatible in all cases; for instance, `pipeline(task="translation_en_to_de")` should work out of the box.
- Should provide a warning when a specific translation pair has been selected on behalf of the user using `model.config.task_specific_params`.

* Update src/transformers/pipelines.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
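The task-name parsing and the fallback to `task_specific_params` described in the message can be sketched roughly as follows. This is only an illustration under stated assumptions, not the code in `src/transformers/pipelines.py`: the helper name `resolve_translation_task`, its return shape, and the warning text are hypothetical, and the config is stood in for by a plain dict.

```python
import warnings
from typing import Dict, Optional, Tuple


def resolve_translation_task(
    task: str, task_specific_params: Optional[Dict[str, dict]] = None
) -> Tuple[str, Optional[Tuple[str, str]]]:
    """Hypothetical helper: return (resolved task name, (src, tgt) or None)."""
    if task.startswith("translation_"):
        # Expect exactly "translation_XX_to_YY".
        tokens = task.split("_")
        if len(tokens) == 4 and tokens[2] == "to":
            return task, (tokens[1], tokens[3])
        raise ValueError(f"Invalid translation task {task!r}, expected 'translation_XX_to_YY'")

    if task == "translation":
        # Assumption: the keys of task_specific_params name the pairs a model supports.
        for key in task_specific_params or {}:
            if key.startswith("translation_"):
                warnings.warn(
                    f'"translation" was requested without a language pair, defaulting to "{key}"',
                    UserWarning,
                )
                return resolve_translation_task(key)
        # Nothing to infer; the caller decides whether this is an error.
        return task, None

    raise ValueError(f"Not a translation task: {task!r}")
```

For example, `resolve_translation_task("translation", {"translation_en_to_de": {}})` emits a `UserWarning` and resolves to the `en`/`de` pair, while an explicit `"translation_cn_to_ar"` is parsed directly without consulting the params.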
@@ -1,6 +1,8 @@
 import unittest
 from typing import Iterable, List, Optional
 
+import pytest
+
 from transformers import pipeline
 from transformers.pipelines import SUPPORTED_TASKS, Conversation, DefaultArgumentHandler, Pipeline
 from transformers.testing_utils import require_tf, require_tokenizers, require_torch, slow, torch_device
@@ -392,6 +394,33 @@ class MonoColumnInputTestCase(unittest.TestCase):
             invalid_inputs,
         )
 
+    @require_torch
+    @slow
+    def test_default_translations(self):
+        # We don't provide a default for this pair
+        with self.assertRaises(ValueError):
+            pipeline(task="translation_cn_to_ar")
+
+        # but we do for this one
+        pipeline(task="translation_en_to_de")
+
+    @require_torch
+    def test_translation_on_odd_language(self):
+        model = TRANSLATION_FINETUNED_MODELS[0][0]
+        pipeline(task="translation_cn_to_ar", model=model)
+
+    @require_torch
+    def test_translation_default_language_selection(self):
+        model = TRANSLATION_FINETUNED_MODELS[0][0]
+        with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
+            nlp = pipeline(task="translation", model=model)
+        self.assertEqual(nlp.task, "translation_en_to_de")
+
+    @require_torch
+    def test_translation_with_no_language_no_model_fails(self):
+        with self.assertRaises(ValueError):
+            pipeline(task="translation")
+
     @require_tf
     @slow
     def test_tf_translation(self):