From 5c153079e26b049b5870344637e754cc09f6bc39 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 10 Nov 2021 10:18:35 +0100 Subject: [PATCH] Adding some quality of life for `pipeline` function. (#14322) * Adding some quality of life for `pipeline` function. * Update docs/source/main_classes/pipelines.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/pipelines/__init__.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Improve the tests. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/main_classes/pipelines.rst | 37 +++++++++++++++++- src/transformers/pipelines/__init__.py | 52 ++++++++++++++++++++++++-- tests/test_pipelines_common.py | 25 +++++++++++++ 3 files changed, 109 insertions(+), 5 deletions(-) diff --git a/docs/source/main_classes/pipelines.rst b/docs/source/main_classes/pipelines.rst index 146c504861..bd78e9e63a 100644 --- a/docs/source/main_classes/pipelines.rst +++ b/docs/source/main_classes/pipelines.rst @@ -45,7 +45,7 @@ The pipeline abstraction ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The `pipeline` abstraction is a wrapper around all the other available pipelines. It is instantiated as any other -pipeline but requires an additional argument which is the `task`. +pipeline but provides additional quality-of-life features. Simple call on one item: @@ -55,6 +55,15 @@ Simple call on one item: >>> pipe("This restaurant is awesome") [{'label': 'POSITIVE', 'score': 0.9998743534088135}] +If you want to use a specific model from the `hub <https://huggingface.co>`__ you can ignore the task if the model on +the hub already defines it: + +.. code-block:: + + >>> pipe = pipeline(model="roberta-large-mnli") + >>> pipe("This restaurant is awesome") + [{'label': 'POSITIVE', 'score': 0.9998743534088135}] + To call a pipeline on many items, you can either call with a `list`. .. 
code-block:: @@ -226,6 +235,32 @@ For users, a rule of thumb is: - The larger the GPU the more likely batching is going to be more interesting - As soon as you enable batching, make sure you can handle OOMs nicely. +Pipeline custom code +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This section is for when you want to override the behavior of a specific pipeline. + +Don't hesitate to create an issue for your task at hand; the goal of the pipeline is to be easy to use and support most +cases, so :obj:`transformers` could maybe support your use case. + + +If you simply want to try it yourself, you can: + +- Subclass your pipeline of choice + +.. code-block:: + + class MyPipeline(TextClassificationPipeline): + def postprocess(...): + ... + scores = scores * 100 + ... + + my_pipeline = MyPipeline(model=model, tokenizer=tokenizer, ...) + # or, if you use the `pipeline` function: + my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline) + +That should let you write all the custom code you want. Implementing a pipeline diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index c7f17bc5a4..afbd41e615 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -2,6 +2,9 @@ # There's no way to ignore "F401 '...' imported but unused" warnings in this # module, but to preserve other warnings. So, don't check this module at all. +import io +import json + # coding=utf-8 # Copyright 2018 The HuggingFace Inc. team. 
# @@ -21,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from ..configuration_utils import PretrainedConfig from ..feature_extraction_utils import PreTrainedFeatureExtractor -from ..file_utils import is_tf_available, is_torch_available +from ..file_utils import http_get, is_tf_available, is_torch_available from ..models.auto.configuration_auto import AutoConfig from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer @@ -248,6 +251,29 @@ SUPPORTED_TASKS = { } +def get_task(model: str, use_auth_token: Optional[str] = None) -> str: + tmp = io.BytesIO() + headers = {} + if use_auth_token: + headers["Authorization"] = f"Bearer {use_auth_token}" + + try: + http_get(f"https://huggingface.co/api/models/{model}", tmp, headers=headers) + tmp.seek(0) + body = tmp.read() + data = json.loads(body) + except Exception as e: + raise RuntimeError(f"Instantiating a pipeline without a task set raised an error: {e}") + if "pipeline_tag" not in data: + raise RuntimeError( + f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically" + ) + if data.get("library_name", "transformers") != "transformers": + raise RuntimeError(f"This model is meant to be used with {data['library_name']} not with transformers") + task = data["pipeline_tag"] + return task + + def check_task(task: str) -> Tuple[Dict, Any]: """ Checks an incoming task string, to validate it's correct and return the default Pipeline and Model classes, and @@ -299,7 +325,7 @@ def check_task(task: str) -> Tuple[Dict, Any]: def pipeline( - task: str, + task: str = None, model: Optional = None, config: Optional[Union[str, PretrainedConfig]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, @@ -309,6 +335,7 @@ def pipeline( use_fast: bool = True, use_auth_token: Optional[Union[str, bool]] = None, model_kwargs: Dict[str, Any] 
= None, + pipeline_class: Optional[Any] = None, **kwargs ) -> Pipeline: """ @@ -422,6 +449,14 @@ def pipeline( """ if model_kwargs is None: model_kwargs = {} + + if task is None and model is None: + raise RuntimeError( + "Impossible to instantiate a pipeline without either a task or a model" + "being specified." + "Please provide a task class or a model" + ) + if model is None and tokenizer is not None: raise RuntimeError( "Impossible to instantiate a pipeline with tokenizer specified but not the model " @@ -435,9 +470,18 @@ def pipeline( "Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing feature_extractor." ) + if task is None and model is not None: + if not isinstance(model, str): + raise RuntimeError( + "Inferring the task automatically requires to check the hub with a model_id defined as a `str`." + f"{model} is not a valid model_id." + ) + task = get_task(model, use_auth_token) + # Retrieve the task targeted_task, task_options = check_task(task) - task_class = targeted_task["impl"] + if pipeline_class is None: + pipeline_class = targeted_task["impl"] # Use default model/config/tokenizer for the task if no model is provided if model is None: @@ -549,4 +593,4 @@ def pipeline( if feature_extractor is not None: kwargs["feature_extractor"] = feature_extractor - return task_class(model=model, framework=framework, task=task, **kwargs) + return pipeline_class(model=model, framework=framework, task=task, **kwargs) diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index 9affebbe15..c005b26450 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -29,8 +29,10 @@ from transformers import ( AutoTokenizer, IBertConfig, RobertaConfig, + TextClassificationPipeline, pipeline, ) +from transformers.pipelines import get_task from transformers.pipelines.base import _pad from transformers.testing_utils import is_pipeline_test, require_torch @@ -261,6 +263,29 @@ class 
CommonPipelineTest(unittest.TestCase): for output in text_classifier(dataset): self.assertEqual(output, {"label": ANY(str), "score": ANY(float)}) + @require_torch + def test_check_task_auto_inference(self): + pipe = pipeline(model="Narsil/tiny-distilbert-sequence-classification") + + self.assertIsInstance(pipe, TextClassificationPipeline) + + @require_torch + def test_pipeline_override(self): + class MyPipeline(TextClassificationPipeline): + pass + + text_classifier = pipeline(model="Narsil/tiny-distilbert-sequence-classification", pipeline_class=MyPipeline) + + self.assertIsInstance(text_classifier, MyPipeline) + + def test_check_task(self): + task = get_task("gpt2") + self.assertEqual(task, "text-generation") + + with self.assertRaises(RuntimeError): + # Wrong framework + get_task("espnet/siddhana_slurp_entity_asr_train_asr_conformer_raw_en_word_valid.acc.ave_10best") + @is_pipeline_test class PipelinePadTest(unittest.TestCase):