Adding some quality of life for pipeline function. (#14322)
* Adding some quality of life for `pipeline` function. * Update docs/source/main_classes/pipelines.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/pipelines/__init__.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Improve the tests. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -45,7 +45,7 @@ The pipeline abstraction
|
|||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
The `pipeline` abstraction is a wrapper around all the other available pipelines. It is instantiated as any other
|
The `pipeline` abstraction is a wrapper around all the other available pipelines. It is instantiated as any other
|
||||||
pipeline but requires an additional argument which is the `task`.
|
pipeline but can provide additional quality of life.
|
||||||
|
|
||||||
Simple call on one item:
|
Simple call on one item:
|
||||||
|
|
||||||
@@ -55,6 +55,15 @@ Simple call on one item:
|
|||||||
>>> pipe("This restaurant is awesome")
|
>>> pipe("This restaurant is awesome")
|
||||||
[{'label': 'POSITIVE', 'score': 0.9998743534088135}]
|
[{'label': 'POSITIVE', 'score': 0.9998743534088135}]
|
||||||
|
|
||||||
|
If you want to use a specific model from the `hub <https://huggingface.co>`__ you can ignore the task if the model on
|
||||||
|
the hub already defines it:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
>>> pipe = pipeline(model="roberta-large-mnli")
|
||||||
|
>>> pipe("This restaurant is awesome")
|
||||||
|
[{'label': 'POSITIVE', 'score': 0.9998743534088135}]
|
||||||
|
|
||||||
To call a pipeline on many items, you can either call with a `list`.
|
To call a pipeline on many items, you can either call with a `list`.
|
||||||
|
|
||||||
.. code-block::
|
.. code-block::
|
||||||
@@ -226,6 +235,32 @@ For users, a rule of thumb is:
|
|||||||
- The larger the GPU the more likely batching is going to be more interesting
|
- The larger the GPU the more likely batching is going to be more interesting
|
||||||
- As soon as you enable batching, make sure you can handle OOMs nicely.
|
- As soon as you enable batching, make sure you can handle OOMs nicely.
|
||||||
|
|
||||||
|
Pipeline custom code
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
If you want to override a specific pipeline.
|
||||||
|
|
||||||
|
Don't hesitate to create an issue for your task at hand, the goal of the pipeline is to be easy to use and support most
|
||||||
|
cases, so :obj:`transformers` could maybe support your use case.
|
||||||
|
|
||||||
|
|
||||||
|
If you want to try simply you can:
|
||||||
|
|
||||||
|
- Subclass your pipeline of choice
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
class MyPipeline(TextClassificationPipeline):
|
||||||
|
def postprocess(...):
|
||||||
|
...
|
||||||
|
scores = scores * 100
|
||||||
|
...
|
||||||
|
|
||||||
|
my_pipeline = MyPipeline(model=model, tokenizer=tokenizer, ...)
|
||||||
|
# or if you use `pipeline` function, then:
|
||||||
|
my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
|
||||||
|
|
||||||
|
That should enable you to do all the custom code you want.
|
||||||
|
|
||||||
|
|
||||||
Implementing a pipeline
|
Implementing a pipeline
|
||||||
|
|||||||
@@ -2,6 +2,9 @@
|
|||||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||||
# module, but to preserve other warnings. So, don't check this module at all.
|
# module, but to preserve other warnings. So, don't check this module at all.
|
||||||
|
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 The HuggingFace Inc. team.
|
# Copyright 2018 The HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
@@ -21,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
|||||||
|
|
||||||
from ..configuration_utils import PretrainedConfig
|
from ..configuration_utils import PretrainedConfig
|
||||||
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
||||||
from ..file_utils import is_tf_available, is_torch_available
|
from ..file_utils import http_get, is_tf_available, is_torch_available
|
||||||
from ..models.auto.configuration_auto import AutoConfig
|
from ..models.auto.configuration_auto import AutoConfig
|
||||||
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
|
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
|
||||||
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||||
@@ -248,6 +251,29 @@ SUPPORTED_TASKS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
|
||||||
|
tmp = io.BytesIO()
|
||||||
|
headers = {}
|
||||||
|
if use_auth_token:
|
||||||
|
headers["Authorization"] = f"Bearer {use_auth_token}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
http_get(f"https://huggingface.co/api/models/{model}", tmp, headers=headers)
|
||||||
|
tmp.seek(0)
|
||||||
|
body = tmp.read()
|
||||||
|
data = json.loads(body)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Instantiating a pipeline without a task set raised an error: {e}")
|
||||||
|
if "pipeline_tag" not in data:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
|
||||||
|
)
|
||||||
|
if data.get("library_name", "transformers") != "transformers":
|
||||||
|
raise RuntimeError(f"This model is meant to be used with {data['library_name']} not with transformers")
|
||||||
|
task = data["pipeline_tag"]
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
def check_task(task: str) -> Tuple[Dict, Any]:
|
def check_task(task: str) -> Tuple[Dict, Any]:
|
||||||
"""
|
"""
|
||||||
Checks an incoming task string, to validate it's correct and return the default Pipeline and Model classes, and
|
Checks an incoming task string, to validate it's correct and return the default Pipeline and Model classes, and
|
||||||
@@ -299,7 +325,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
|
|||||||
|
|
||||||
|
|
||||||
def pipeline(
|
def pipeline(
|
||||||
task: str,
|
task: str = None,
|
||||||
model: Optional = None,
|
model: Optional = None,
|
||||||
config: Optional[Union[str, PretrainedConfig]] = None,
|
config: Optional[Union[str, PretrainedConfig]] = None,
|
||||||
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
|
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
|
||||||
@@ -309,6 +335,7 @@ def pipeline(
|
|||||||
use_fast: bool = True,
|
use_fast: bool = True,
|
||||||
use_auth_token: Optional[Union[str, bool]] = None,
|
use_auth_token: Optional[Union[str, bool]] = None,
|
||||||
model_kwargs: Dict[str, Any] = None,
|
model_kwargs: Dict[str, Any] = None,
|
||||||
|
pipeline_class: Optional[Any] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Pipeline:
|
) -> Pipeline:
|
||||||
"""
|
"""
|
||||||
@@ -422,6 +449,14 @@ def pipeline(
|
|||||||
"""
|
"""
|
||||||
if model_kwargs is None:
|
if model_kwargs is None:
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
|
|
||||||
|
if task is None and model is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Impossible to instantiate a pipeline without either a task or a model"
|
||||||
|
"being specified."
|
||||||
|
"Please provide a task class or a model"
|
||||||
|
)
|
||||||
|
|
||||||
if model is None and tokenizer is not None:
|
if model is None and tokenizer is not None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Impossible to instantiate a pipeline with tokenizer specified but not the model "
|
"Impossible to instantiate a pipeline with tokenizer specified but not the model "
|
||||||
@@ -435,9 +470,18 @@ def pipeline(
|
|||||||
"Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing feature_extractor."
|
"Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing feature_extractor."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if task is None and model is not None:
|
||||||
|
if not isinstance(model, str):
|
||||||
|
raise RuntimeError(
|
||||||
|
"Inferring the task automatically requires to check the hub with a model_id defined as a `str`."
|
||||||
|
f"{model} is not a valid model_id."
|
||||||
|
)
|
||||||
|
task = get_task(model, use_auth_token)
|
||||||
|
|
||||||
# Retrieve the task
|
# Retrieve the task
|
||||||
targeted_task, task_options = check_task(task)
|
targeted_task, task_options = check_task(task)
|
||||||
task_class = targeted_task["impl"]
|
if pipeline_class is None:
|
||||||
|
pipeline_class = targeted_task["impl"]
|
||||||
|
|
||||||
# Use default model/config/tokenizer for the task if no model is provided
|
# Use default model/config/tokenizer for the task if no model is provided
|
||||||
if model is None:
|
if model is None:
|
||||||
@@ -549,4 +593,4 @@ def pipeline(
|
|||||||
if feature_extractor is not None:
|
if feature_extractor is not None:
|
||||||
kwargs["feature_extractor"] = feature_extractor
|
kwargs["feature_extractor"] = feature_extractor
|
||||||
|
|
||||||
return task_class(model=model, framework=framework, task=task, **kwargs)
|
return pipeline_class(model=model, framework=framework, task=task, **kwargs)
|
||||||
|
|||||||
@@ -29,8 +29,10 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
IBertConfig,
|
IBertConfig,
|
||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
|
TextClassificationPipeline,
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
|
from transformers.pipelines import get_task
|
||||||
from transformers.pipelines.base import _pad
|
from transformers.pipelines.base import _pad
|
||||||
from transformers.testing_utils import is_pipeline_test, require_torch
|
from transformers.testing_utils import is_pipeline_test, require_torch
|
||||||
|
|
||||||
@@ -261,6 +263,29 @@ class CommonPipelineTest(unittest.TestCase):
|
|||||||
for output in text_classifier(dataset):
|
for output in text_classifier(dataset):
|
||||||
self.assertEqual(output, {"label": ANY(str), "score": ANY(float)})
|
self.assertEqual(output, {"label": ANY(str), "score": ANY(float)})
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_check_task_auto_inference(self):
|
||||||
|
pipe = pipeline(model="Narsil/tiny-distilbert-sequence-classification")
|
||||||
|
|
||||||
|
self.assertIsInstance(pipe, TextClassificationPipeline)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_pipeline_override(self):
|
||||||
|
class MyPipeline(TextClassificationPipeline):
|
||||||
|
pass
|
||||||
|
|
||||||
|
text_classifier = pipeline(model="Narsil/tiny-distilbert-sequence-classification", pipeline_class=MyPipeline)
|
||||||
|
|
||||||
|
self.assertIsInstance(text_classifier, MyPipeline)
|
||||||
|
|
||||||
|
def test_check_task(self):
|
||||||
|
task = get_task("gpt2")
|
||||||
|
self.assertEqual(task, "text-generation")
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
# Wrong framework
|
||||||
|
get_task("espnet/siddhana_slurp_entity_asr_train_asr_conformer_raw_en_word_valid.acc.ave_10best")
|
||||||
|
|
||||||
|
|
||||||
@is_pipeline_test
|
@is_pipeline_test
|
||||||
class PipelinePadTest(unittest.TestCase):
|
class PipelinePadTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user