feat: add pipeline registry abstraction (#17905)
* feat: add pipeline registry abstraction
- add `PipelineRegistry` abstraction
- update `add_new_pipeline.mdx` (English docs) to reflect the API addition
- migrate `check_task` and `get_supported_tasks` from
transformers/pipelines/__init__.py to
transformers/pipelines/base.py#PipelineRegistry.{check_task,get_supported_tasks}
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* fix: update with upstream/main
chore: Apply suggestions from sgugger's code review
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* chore: PR updates
- revert src/transformers/dependency_versions_table.py from upstream/main
- updates pipeline registry to use global variables
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* tests: add tests for pipeline registry
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* tests: add test for output warning.
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* chore: fmt and cleanup unused imports
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* fix: change imports to top of the file and address comments
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -111,8 +111,35 @@ of arguments for ease of use (audio files, can be filenames, URLs or pure bytes)
|
|||||||
|
|
||||||
## Adding it to the list of supported tasks
|
## Adding it to the list of supported tasks
|
||||||
|
|
||||||
Go to `src/transformers/pipelines/__init__.py` and fill in `SUPPORTED_TASKS` with your newly created pipeline.
|
To register your `new-task` to the list of supported tasks, provide the
|
||||||
If possible it should provide a default model.
|
following task template:
|
||||||
|
|
||||||
|
```python
|
||||||
|
my_new_task = {
|
||||||
|
"impl": MyPipeline,
|
||||||
|
"tf": (),
|
||||||
|
"pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
|
||||||
|
"default": {"model": {"pt": "user/awesome_model"}},
|
||||||
|
"type": "audio",  # currently supported types: text, audio, image, multimodal
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Take a look at `src/transformers/pipelines/__init__.py` and the dictionary `SUPPORTED_TASKS` to see how a task is defined.
|
||||||
|
If possible your custom task should provide a default model.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Then add your custom task to the list of supported tasks via
|
||||||
|
`PIPELINE_REGISTRY.register_pipeline()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers.pipelines import PIPELINE_REGISTRY
|
||||||
|
|
||||||
|
PIPELINE_REGISTRY.register_pipeline("new-task", my_new_task)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## Adding tests
|
## Adding tests
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from .base import (
|
|||||||
Pipeline,
|
Pipeline,
|
||||||
PipelineDataFormat,
|
PipelineDataFormat,
|
||||||
PipelineException,
|
PipelineException,
|
||||||
|
PipelineRegistry,
|
||||||
get_default_model_and_revision,
|
get_default_model_and_revision,
|
||||||
infer_framework_load_model,
|
infer_framework_load_model,
|
||||||
)
|
)
|
||||||
@@ -309,14 +310,14 @@ for task, values in SUPPORTED_TASKS.items():
|
|||||||
elif values["type"] != "multimodal":
|
elif values["type"] != "multimodal":
|
||||||
raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
|
raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
|
||||||
|
|
||||||
|
PIPELINE_REGISTRY = PipelineRegistry(supported_tasks=SUPPORTED_TASKS, task_aliases=TASK_ALIASES)
|
||||||
|
|
||||||
|
|
||||||
def get_supported_tasks() -> List[str]:
|
def get_supported_tasks() -> List[str]:
|
||||||
"""
|
"""
|
||||||
Returns a list of supported task strings.
|
Returns a list of supported task strings.
|
||||||
"""
|
"""
|
||||||
supported_tasks = list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys())
|
return PIPELINE_REGISTRY.get_supported_tasks()
|
||||||
supported_tasks.sort()
|
|
||||||
return supported_tasks
|
|
||||||
|
|
||||||
|
|
||||||
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
|
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
|
||||||
@@ -375,20 +376,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if task in TASK_ALIASES:
|
return PIPELINE_REGISTRY.check_task(task)
|
||||||
task = TASK_ALIASES[task]
|
|
||||||
if task in SUPPORTED_TASKS:
|
|
||||||
targeted_task = SUPPORTED_TASKS[task]
|
|
||||||
return targeted_task, None
|
|
||||||
|
|
||||||
if task.startswith("translation"):
|
|
||||||
tokens = task.split("_")
|
|
||||||
if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
|
|
||||||
targeted_task = SUPPORTED_TASKS["translation"]
|
|
||||||
return targeted_task, (tokens[1], tokens[3])
|
|
||||||
raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
|
|
||||||
|
|
||||||
raise KeyError(f"Unknown task {task}, available tasks are {get_supported_tasks() + ['translation_XX_to_YY']}")
|
|
||||||
|
|
||||||
|
|
||||||
def pipeline(
|
def pipeline(
|
||||||
|
|||||||
@@ -1087,3 +1087,41 @@ class ChunkPipeline(Pipeline):
|
|||||||
model_iterator = PipelinePackIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
|
model_iterator = PipelinePackIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
|
||||||
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
|
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
|
||||||
return final_iterator
|
return final_iterator
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineRegistry:
    """
    Registry of pipeline tasks, backed by a task-definition mapping and an alias mapping.

    Args:
        supported_tasks (`Dict[str, Any]`):
            Mapping from task name to its definition. The mapping is stored as-is (not copied), so
            [`~PipelineRegistry.register_pipeline`] mutates the very object passed in here.
        task_aliases (`Dict[str, str]`):
            Mapping from an alias to the task name it stands for.
    """

    def __init__(self, supported_tasks: Dict[str, Any], task_aliases: Dict[str, str]) -> None:
        self.supported_tasks = supported_tasks
        self.task_aliases = task_aliases

    def get_supported_tasks(self) -> List[str]:
        """
        Returns the names of all registered tasks together with all aliases, sorted alphabetically.
        """
        return sorted(list(self.supported_tasks) + list(self.task_aliases))

    def check_task(self, task: str) -> Tuple[Dict, Any]:
        """
        Resolves `task` to its registered definition.

        Args:
            task (`str`):
                A registered task name, an alias, or a string of the form `"translation_XX_to_YY"`.

        Returns:
            `Tuple[Dict, Any]`: The task definition and its options. The options are `None`, except for
            `"translation_XX_to_YY"` tasks where they are the tuple `(XX, YY)`.

        Raises:
            `KeyError`: If `task` is unknown, or starts with `"translation"` without matching the
            `"translation_XX_to_YY"` format.
        """
        # Aliases are resolved first, so an alias always wins over a task registered under the same name.
        if task in self.task_aliases:
            task = self.task_aliases[task]
        if task in self.supported_tasks:
            targeted_task = self.supported_tasks[task]
            return targeted_task, None

        if task.startswith("translation"):
            tokens = task.split("_")
            if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
                targeted_task = self.supported_tasks["translation"]
                return targeted_task, (tokens[1], tokens[3])
            raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")

        raise KeyError(
            f"Unknown task {task}, available tasks are {self.get_supported_tasks() + ['translation_XX_to_YY']}"
        )

    def register_pipeline(self, task: str, task_impl: Dict[str, Any]) -> None:
        """
        Registers `task_impl` under the name `task`.

        An existing entry with the same name is overwritten, after a warning is logged.

        Args:
            task (`str`): Name to register the task under.
            task_impl (`Dict[str, Any]`): Definition of the task; stored without validation.
        """
        if task in self.supported_tasks:
            logger.warning(f"{task} is already registered. Overwriting pipeline for task {task}...")

        self.supported_tasks[task] = task_impl

    def to_dict(self) -> Dict[str, Any]:
        """
        Returns the underlying task mapping itself (not a copy), so changes made by the caller are
        visible to the registry.
        """
        return self.supported_tasks
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from transformers import (
|
|||||||
FEATURE_EXTRACTOR_MAPPING,
|
FEATURE_EXTRACTOR_MAPPING,
|
||||||
TOKENIZER_MAPPING,
|
TOKENIZER_MAPPING,
|
||||||
AutoFeatureExtractor,
|
AutoFeatureExtractor,
|
||||||
|
AutoModelForSequenceClassification,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
DistilBertForSequenceClassification,
|
DistilBertForSequenceClassification,
|
||||||
IBertConfig,
|
IBertConfig,
|
||||||
@@ -35,9 +36,10 @@ from transformers import (
|
|||||||
TextClassificationPipeline,
|
TextClassificationPipeline,
|
||||||
pipeline,
|
pipeline,
|
||||||
)
|
)
|
||||||
from transformers.pipelines import get_task
|
from transformers.pipelines import PIPELINE_REGISTRY, get_task
|
||||||
from transformers.pipelines.base import _pad
|
from transformers.pipelines.base import Pipeline, _pad
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
|
CaptureLogger,
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_scatter,
|
require_scatter,
|
||||||
@@ -46,6 +48,7 @@ from transformers.testing_utils import (
|
|||||||
require_torch,
|
require_torch,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
|
from transformers.utils import logging as transformers_logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -746,3 +749,51 @@ class PipelineUtilsTest(unittest.TestCase):
|
|||||||
models_are_equal = False
|
models_are_equal = False
|
||||||
|
|
||||||
return models_are_equal
|
return models_are_equal
|
||||||
|
|
||||||
|
|
||||||
|
class CustomPipeline(Pipeline):
    """Minimal `Pipeline` subclass used as the implementation of a custom task in the registry tests."""

    def _sanitize_parameters(self, **kwargs):
        """
        Splits call-time keyword arguments into three dictionaries.

        Only `maybe_arg` is recognised, and it is routed to the first dictionary; the other two are
        always empty (presumably the forward and postprocess kwargs -- confirm against `Pipeline`).
        """
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, maybe_arg=2):
        """Tokenizes `text` into PyTorch tensors. `maybe_arg` is accepted but not used."""
        input_ids = self.tokenizer(text, return_tensors="pt")
        return input_ids

    def _forward(self, model_inputs):
        """Runs the model on the tokenized inputs and returns its raw outputs."""
        outputs = self.model(**model_inputs)
        return outputs

    def postprocess(self, model_outputs):
        """Applies a softmax over the last dimension of the logits and converts the result to NumPy."""
        return model_outputs["logits"].softmax(-1).numpy()
|
||||||
|
|
||||||
|
|
||||||
|
@is_pipeline_test
class PipelineRegistryTest(unittest.TestCase):
    """Tests for registering tasks on the global `PIPELINE_REGISTRY`."""

    def test_warning_logs(self):
        """Registering a task name that already exists logs an "already registered" warning."""
        transformers_logging.set_verbosity_debug()
        logger_ = transformers_logging.get_logger("transformers.pipelines.base")

        alias = "text-classification"
        # PIPELINE_REGISTRY is a module-level object shared by every test in the process. Keep the real
        # definition so it can be put back; otherwise later tests would see `{}` for this task.
        original_task = PIPELINE_REGISTRY.supported_tasks[alias]

        try:
            with CaptureLogger(logger_) as cm:
                PIPELINE_REGISTRY.register_pipeline(alias, {})
            self.assertIn(f"{alias} is already registered", cm.out)
        finally:
            PIPELINE_REGISTRY.supported_tasks[alias] = original_task

    @require_torch
    def test_register_pipeline(self):
        """A newly registered task is listed and resolves to exactly the definition that was given."""
        custom_text_classification = {
            "impl": CustomPipeline,
            "tf": (),
            "pt": (AutoModelForSequenceClassification,),
            "default": {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}},
            "type": "text",
        }
        # Remove the custom task again once the test ends so the shared registry is left untouched.
        self.addCleanup(PIPELINE_REGISTRY.supported_tasks.pop, "custom-text-classification", None)

        PIPELINE_REGISTRY.register_pipeline("custom-text-classification", custom_text_classification)
        self.assertIn("custom-text-classification", PIPELINE_REGISTRY.get_supported_tasks())

        task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
        self.assertEqual(task_def, custom_text_classification)
        self.assertEqual(task_def["type"], "text")
        self.assertEqual(task_def["impl"], CustomPipeline)
|
||||||
|
|||||||
Reference in New Issue
Block a user