Adding some quality of life for pipeline function. (#14322)
* Adding some quality of life for `pipeline` function. * Update docs/source/main_classes/pipelines.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/pipelines/__init__.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Improve the tests. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -45,7 +45,7 @@ The pipeline abstraction
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The `pipeline` abstraction is a wrapper around all the other available pipelines. It is instantiated as any other
|
||||
pipeline but requires an additional argument which is the `task`.
|
||||
pipeline but can provide additional quality of life.
|
||||
|
||||
Simple call on one item:
|
||||
|
||||
@@ -55,6 +55,15 @@ Simple call on one item:
|
||||
>>> pipe("This restaurant is awesome")
|
||||
[{'label': 'POSITIVE', 'score': 0.9998743534088135}]
|
||||
|
||||
If you want to use a specific model from the `hub <https://huggingface.co>`__ you can ignore the task if the model on
|
||||
the hub already defines it:
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> pipe = pipeline(model="roberta-large-mnli")
|
||||
>>> pipe("This restaurant is awesome")
|
||||
[{'label': 'POSITIVE', 'score': 0.9998743534088135}]
|
||||
|
||||
To call a pipeline on many items, you can either call with a `list`.
|
||||
|
||||
.. code-block::
|
||||
@@ -226,6 +235,32 @@ For users, a rule of thumb is:
|
||||
- The larger the GPU the more likely batching is going to be more interesting
|
||||
- As soon as you enable batching, make sure you can handle OOMs nicely.
|
||||
|
||||
Pipeline custom code
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
If you want to override a specific pipeline.
|
||||
|
||||
Don't hesitate to create an issue for your task at hand, the goal of the pipeline is to be easy to use and support most
|
||||
cases, so :obj:`transformers` could maybe support your use case.
|
||||
|
||||
|
||||
If you want to try simply you can:
|
||||
|
||||
- Subclass your pipeline of choice
|
||||
|
||||
.. code-block::
|
||||
|
||||
class MyPipeline(TextClassificationPipeline):
|
||||
def postprocess(...):
|
||||
...
|
||||
scores = scores * 100
|
||||
...
|
||||
|
||||
my_pipeline = MyPipeline(model=model, tokenizer=tokenizer, ...)
|
||||
# or if you use `pipeline` function, then:
|
||||
my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
|
||||
|
||||
That should enable you to do all the custom code you want.
|
||||
|
||||
|
||||
Implementing a pipeline
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
import io
|
||||
import json
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
@@ -21,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from ..configuration_utils import PretrainedConfig
|
||||
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
||||
from ..file_utils import is_tf_available, is_torch_available
|
||||
from ..file_utils import http_get, is_tf_available, is_torch_available
|
||||
from ..models.auto.configuration_auto import AutoConfig
|
||||
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
|
||||
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
@@ -248,6 +251,29 @@ SUPPORTED_TASKS = {
|
||||
}
|
||||
|
||||
|
||||
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
|
||||
tmp = io.BytesIO()
|
||||
headers = {}
|
||||
if use_auth_token:
|
||||
headers["Authorization"] = f"Bearer {use_auth_token}"
|
||||
|
||||
try:
|
||||
http_get(f"https://huggingface.co/api/models/{model}", tmp, headers=headers)
|
||||
tmp.seek(0)
|
||||
body = tmp.read()
|
||||
data = json.loads(body)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Instantiating a pipeline without a task set raised an error: {e}")
|
||||
if "pipeline_tag" not in data:
|
||||
raise RuntimeError(
|
||||
f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
|
||||
)
|
||||
if data.get("library_name", "transformers") != "transformers":
|
||||
raise RuntimeError(f"This model is meant to be used with {data['library_name']} not with transformers")
|
||||
task = data["pipeline_tag"]
|
||||
return task
|
||||
|
||||
|
||||
def check_task(task: str) -> Tuple[Dict, Any]:
|
||||
"""
|
||||
Checks an incoming task string, to validate it's correct and return the default Pipeline and Model classes, and
|
||||
@@ -299,7 +325,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
|
||||
|
||||
|
||||
def pipeline(
|
||||
task: str,
|
||||
task: str = None,
|
||||
model: Optional = None,
|
||||
config: Optional[Union[str, PretrainedConfig]] = None,
|
||||
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
|
||||
@@ -309,6 +335,7 @@ def pipeline(
|
||||
use_fast: bool = True,
|
||||
use_auth_token: Optional[Union[str, bool]] = None,
|
||||
model_kwargs: Dict[str, Any] = None,
|
||||
pipeline_class: Optional[Any] = None,
|
||||
**kwargs
|
||||
) -> Pipeline:
|
||||
"""
|
||||
@@ -422,6 +449,14 @@ def pipeline(
|
||||
"""
|
||||
if model_kwargs is None:
|
||||
model_kwargs = {}
|
||||
|
||||
if task is None and model is None:
|
||||
raise RuntimeError(
|
||||
"Impossible to instantiate a pipeline without either a task or a model"
|
||||
"being specified."
|
||||
"Please provide a task class or a model"
|
||||
)
|
||||
|
||||
if model is None and tokenizer is not None:
|
||||
raise RuntimeError(
|
||||
"Impossible to instantiate a pipeline with tokenizer specified but not the model "
|
||||
@@ -435,9 +470,18 @@ def pipeline(
|
||||
"Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing feature_extractor."
|
||||
)
|
||||
|
||||
if task is None and model is not None:
|
||||
if not isinstance(model, str):
|
||||
raise RuntimeError(
|
||||
"Inferring the task automatically requires to check the hub with a model_id defined as a `str`."
|
||||
f"{model} is not a valid model_id."
|
||||
)
|
||||
task = get_task(model, use_auth_token)
|
||||
|
||||
# Retrieve the task
|
||||
targeted_task, task_options = check_task(task)
|
||||
task_class = targeted_task["impl"]
|
||||
if pipeline_class is None:
|
||||
pipeline_class = targeted_task["impl"]
|
||||
|
||||
# Use default model/config/tokenizer for the task if no model is provided
|
||||
if model is None:
|
||||
@@ -549,4 +593,4 @@ def pipeline(
|
||||
if feature_extractor is not None:
|
||||
kwargs["feature_extractor"] = feature_extractor
|
||||
|
||||
return task_class(model=model, framework=framework, task=task, **kwargs)
|
||||
return pipeline_class(model=model, framework=framework, task=task, **kwargs)
|
||||
|
||||
@@ -29,8 +29,10 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
IBertConfig,
|
||||
RobertaConfig,
|
||||
TextClassificationPipeline,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.pipelines import get_task
|
||||
from transformers.pipelines.base import _pad
|
||||
from transformers.testing_utils import is_pipeline_test, require_torch
|
||||
|
||||
@@ -261,6 +263,29 @@ class CommonPipelineTest(unittest.TestCase):
|
||||
for output in text_classifier(dataset):
|
||||
self.assertEqual(output, {"label": ANY(str), "score": ANY(float)})
|
||||
|
||||
@require_torch
|
||||
def test_check_task_auto_inference(self):
|
||||
pipe = pipeline(model="Narsil/tiny-distilbert-sequence-classification")
|
||||
|
||||
self.assertIsInstance(pipe, TextClassificationPipeline)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_override(self):
|
||||
class MyPipeline(TextClassificationPipeline):
|
||||
pass
|
||||
|
||||
text_classifier = pipeline(model="Narsil/tiny-distilbert-sequence-classification", pipeline_class=MyPipeline)
|
||||
|
||||
self.assertIsInstance(text_classifier, MyPipeline)
|
||||
|
||||
def test_check_task(self):
|
||||
task = get_task("gpt2")
|
||||
self.assertEqual(task, "text-generation")
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
# Wrong framework
|
||||
get_task("espnet/siddhana_slurp_entity_asr_train_asr_conformer_raw_en_word_valid.acc.ave_10best")
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class PipelinePadTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user