From 0911b6bd86b39d55ddeae42fbecef75a1244ea85 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 15 Oct 2020 09:42:07 +0200 Subject: [PATCH] Improving Pipelines by defaulting to framework='tf' when pytorch seems unavailable. (#7728) * Improving Pipelines by defaulting to framework='tf' when pytorch seems unavailable. * Actually changing the default resolution order to account for model defaults Adding new tests for each pipeline to check that pipeline(task) also works without manually specifying the framework. --- src/transformers/pipelines.py | 62 +++++++++++++++++++++++++++-------- tests/test_pipelines.py | 11 +++++++ 2 files changed, 59 insertions(+), 14 deletions(-) diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 9edc6380cd..33e7efaea6 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -85,31 +85,63 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -def get_framework(model=None): +def get_framework(model): """ Select framework (TensorFlow or PyTorch) to use. Args: - model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`, `optional`): + model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`): If both frameworks are installed, picks the one corresponding to the model passed (either a model class or the model name). If no specific model is provided, defaults to using PyTorch. """ - if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str): - # Both framework are available but the user supplied a model class instance. - # Try to guess which framework to use from the model classname - framework = "tf" if model.__class__.__name__.startswith("TF") else "pt" - elif not is_tf_available() and not is_torch_available(): + if not is_tf_available() and not is_torch_available(): raise RuntimeError( "At least one of TensorFlow 2.0 or PyTorch should be installed. 
" "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " "To install PyTorch, read the instructions at https://pytorch.org/." ) - else: - # framework = 'tf' if is_tf_available() else 'pt' - framework = "pt" if is_torch_available() else "tf" + if isinstance(model, str): + if is_torch_available() and not is_tf_available(): + model = AutoModel.from_pretrained(model) + elif is_tf_available() and not is_torch_available(): + model = TFAutoModel.from_pretrained(model) + else: + try: + model = AutoModel.from_pretrained(model) + except OSError: + model = TFAutoModel.from_pretrained(model) + + framework = "tf" if model.__class__.__name__.startswith("TF") else "pt" return framework +def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str: + """ + Select a default model to use for a given task. Defaults to pytorch if ambiguous. + + Args: + targeted_task (:obj:`Dict` ): + Dictionary representing the given task, that should contain default models + + framework (:obj:`str`, None): + "pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet. + + Returns: + + :obj:`str` The model string representing the default model for this pipeline + """ + if is_torch_available() and not is_tf_available(): + framework = "pt" + elif is_tf_available() and not is_torch_available(): + framework = "tf" + + default_models = targeted_task["default"]["model"] + if framework is None: + framework = "pt" + + return default_models[framework] + + class PipelineException(Exception): """ Raised by a :class:`~transformers.Pipeline` when handling __call__. 
@@ -2685,14 +2717,16 @@ def pipeline( if task not in SUPPORTED_TASKS: raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys()))) - framework = framework or get_framework(model) - targeted_task = SUPPORTED_TASKS[task] - task_class, model_class = targeted_task["impl"], targeted_task[framework] # Use default model/config/tokenizer for the task if no model is provided if model is None: - model = targeted_task["default"]["model"][framework] + # At this point the framework might still be undetermined + model = get_default_model(targeted_task, framework) + + framework = framework or get_framework(model) + + task_class, model_class = targeted_task["impl"], targeted_task[framework] # Try to infer tokenizer from model or config name (if provided as str) if tokenizer is None: diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 1815787507..e9ada812da 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -10,6 +10,7 @@ DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0 VALID_INPUTS = ["A simple string", ["list of strings"]] NER_FINETUNED_MODELS = ["sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"] +TF_NER_FINETUNED_MODELS = ["Narsil/small"] # xlnet-base-cased disabled for now, since it crashes TF2 FEATURE_EXTRACT_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased"] @@ -804,6 +805,14 @@ class NerPipelineTests(unittest.TestCase): nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True) self._test_ner_pipeline(nlp, mandatory_keys) + @require_tf + def test_tf_only_ner(self): + mandatory_keys = {"entity", "word", "score"} + for model_name in TF_NER_FINETUNED_MODELS: + # We don't specify framework='tf' but it gets detected automatically + nlp = pipeline(task="ner", model=model_name, tokenizer=model_name) + self._test_ner_pipeline(nlp, mandatory_keys) + class PipelineCommonTests(unittest.TestCase): pipelines = SUPPORTED_TASKS.keys() @@ 
-815,6 +824,7 @@ class PipelineCommonTests(unittest.TestCase): for task in self.pipelines: with self.subTest(msg="Testing TF defaults with TF and {}".format(task)): pipeline(task, framework="tf") + pipeline(task) @require_torch @slow @@ -823,3 +833,4 @@ class PipelineCommonTests(unittest.TestCase): for task in self.pipelines: with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)): pipeline(task, framework="pt") + pipeline(task)