feat: add pipeline registry abstraction (#17905)
* feat: add pipeline registry abstraction
- added `PipelineRegistry` abstraction
- updates `add_new_pipeline.mdx` (english docs) to reflect the api addition
- migrate `check_task` and `get_supported_tasks` from
transformers/pipelines/__init__.py to
transformers/pipelines/base.py#PipelineRegistry.{check_task,get_supported_tasks}
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* fix: update with upstream/main
chore: Apply suggestions from sgugger's code review
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* chore: PR updates
- revert src/transformers/dependency_versions_table.py from upstream/main
- updates pipeline registry to use global variables
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* tests: add tests for pipeline registry
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* tests: add test for output warning.
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* chore: fmt and cleanup unused imports
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* fix: change imports to top of the file and address comments
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -111,8 +111,35 @@ of arguments for ease of use (audio files, can be filenames, URLs or pure bytes)
|
||||
|
||||
## Adding it to the list of supported tasks
|
||||
|
||||
Go to `src/transformers/pipelines/__init__.py` and fill in `SUPPORTED_TASKS` with your newly created pipeline.
|
||||
If possible it should provide a default model.
|
||||
To register your `new-task` to the list of supported tasks, provide the
|
||||
following task template:
|
||||
|
||||
```python
|
||||
my_new_task = {
|
||||
"impl": MyPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": "user/awesome_model"}},
|
||||
"type": "audio", # current support type: text, audio, image, multimodal
|
||||
}
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
Take a look at the `src/transformers/pipelines/__init__.py` and the dictionary `SUPPORTED_TASKS` to see how a task is defined.
|
||||
If possible your custom task should provide a default model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Then add your custom task to the list of supported tasks via
|
||||
`PIPELINE_REGISTRY.register_pipeline()`:
|
||||
|
||||
```python
|
||||
from transformers.pipelines import PIPELINE_REGISTRY
|
||||
|
||||
PIPELINE_REGISTRY.register_pipeline("new-task", my_new_task)
|
||||
```
|
||||
|
||||
|
||||
## Adding tests
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ from .base import (
|
||||
Pipeline,
|
||||
PipelineDataFormat,
|
||||
PipelineException,
|
||||
PipelineRegistry,
|
||||
get_default_model_and_revision,
|
||||
infer_framework_load_model,
|
||||
)
|
||||
@@ -309,14 +310,14 @@ for task, values in SUPPORTED_TASKS.items():
|
||||
elif values["type"] != "multimodal":
|
||||
raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
|
||||
|
||||
PIPELINE_REGISTRY = PipelineRegistry(supported_tasks=SUPPORTED_TASKS, task_aliases=TASK_ALIASES)
|
||||
|
||||
|
||||
def get_supported_tasks() -> List[str]:
|
||||
"""
|
||||
Returns a list of supported task strings.
|
||||
"""
|
||||
supported_tasks = list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys())
|
||||
supported_tasks.sort()
|
||||
return supported_tasks
|
||||
return PIPELINE_REGISTRY.get_supported_tasks()
|
||||
|
||||
|
||||
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
|
||||
@@ -375,20 +376,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
|
||||
|
||||
|
||||
"""
|
||||
if task in TASK_ALIASES:
|
||||
task = TASK_ALIASES[task]
|
||||
if task in SUPPORTED_TASKS:
|
||||
targeted_task = SUPPORTED_TASKS[task]
|
||||
return targeted_task, None
|
||||
|
||||
if task.startswith("translation"):
|
||||
tokens = task.split("_")
|
||||
if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
|
||||
targeted_task = SUPPORTED_TASKS["translation"]
|
||||
return targeted_task, (tokens[1], tokens[3])
|
||||
raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
|
||||
|
||||
raise KeyError(f"Unknown task {task}, available tasks are {get_supported_tasks() + ['translation_XX_to_YY']}")
|
||||
return PIPELINE_REGISTRY.check_task(task)
|
||||
|
||||
|
||||
def pipeline(
|
||||
|
||||
@@ -1087,3 +1087,41 @@ class ChunkPipeline(Pipeline):
|
||||
model_iterator = PipelinePackIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
|
||||
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
|
||||
return final_iterator
|
||||
|
||||
|
||||
class PipelineRegistry:
|
||||
def __init__(self, supported_tasks: Dict[str, Any], task_aliases: Dict[str, str]) -> None:
|
||||
self.supported_tasks = supported_tasks
|
||||
self.task_aliases = task_aliases
|
||||
|
||||
def get_supported_tasks(self) -> List[str]:
|
||||
supported_task = list(self.supported_tasks.keys()) + list(self.task_aliases.keys())
|
||||
supported_task.sort()
|
||||
return supported_task
|
||||
|
||||
def check_task(self, task: str) -> Tuple[Dict, Any]:
|
||||
if task in self.task_aliases:
|
||||
task = self.task_aliases[task]
|
||||
if task in self.supported_tasks:
|
||||
targeted_task = self.supported_tasks[task]
|
||||
return targeted_task, None
|
||||
|
||||
if task.startswith("translation"):
|
||||
tokens = task.split("_")
|
||||
if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
|
||||
targeted_task = self.supported_tasks["translation"]
|
||||
return targeted_task, (tokens[1], tokens[3])
|
||||
raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
|
||||
|
||||
raise KeyError(
|
||||
f"Unknown task {task}, available tasks are {self.get_supported_tasks() + ['translation_XX_to_YY']}"
|
||||
)
|
||||
|
||||
def register_pipeline(self, task: str, task_impl: Dict[str, Any]) -> None:
|
||||
if task in self.supported_tasks:
|
||||
logger.warning(f"{task} is already registered. Overwriting pipeline for task {task}...")
|
||||
|
||||
self.supported_tasks[task] = task_impl
|
||||
|
||||
def to_dict(self):
|
||||
return self.supported_tasks
|
||||
|
||||
@@ -28,6 +28,7 @@ from transformers import (
|
||||
FEATURE_EXTRACTOR_MAPPING,
|
||||
TOKENIZER_MAPPING,
|
||||
AutoFeatureExtractor,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
DistilBertForSequenceClassification,
|
||||
IBertConfig,
|
||||
@@ -35,9 +36,10 @@ from transformers import (
|
||||
TextClassificationPipeline,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.pipelines import get_task
|
||||
from transformers.pipelines.base import _pad
|
||||
from transformers.pipelines import PIPELINE_REGISTRY, get_task
|
||||
from transformers.pipelines.base import Pipeline, _pad
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_pipeline_test,
|
||||
nested_simplify,
|
||||
require_scatter,
|
||||
@@ -46,6 +48,7 @@ from transformers.testing_utils import (
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import logging as transformers_logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -746,3 +749,51 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
models_are_equal = False
|
||||
|
||||
return models_are_equal
|
||||
|
||||
|
||||
class CustomPipeline(Pipeline):
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
if "maybe_arg" in kwargs:
|
||||
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
|
||||
return preprocess_kwargs, {}, {}
|
||||
|
||||
def preprocess(self, text, maybe_arg=2):
|
||||
input_ids = self.tokenizer(text, return_tensors="pt")
|
||||
return input_ids
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
outputs = self.model(**model_inputs)
|
||||
return outputs
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
return model_outputs["logits"].softmax(-1).numpy()
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class PipelineRegistryTest(unittest.TestCase):
|
||||
def test_warning_logs(self):
|
||||
transformers_logging.set_verbosity_debug()
|
||||
logger_ = transformers_logging.get_logger("transformers.pipelines.base")
|
||||
|
||||
alias = "text-classification"
|
||||
with CaptureLogger(logger_) as cm:
|
||||
PIPELINE_REGISTRY.register_pipeline(alias, {})
|
||||
self.assertIn(f"{alias} is already registered", cm.out)
|
||||
|
||||
@require_torch
|
||||
def test_register_pipeline(self):
|
||||
custom_text_classification = {
|
||||
"impl": CustomPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForSequenceClassification,),
|
||||
"default": {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}},
|
||||
"type": "text",
|
||||
}
|
||||
PIPELINE_REGISTRY.register_pipeline("custom-text-classification", custom_text_classification)
|
||||
assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()
|
||||
|
||||
task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
|
||||
self.assertEqual(task_def, custom_text_classification)
|
||||
self.assertEqual(task_def["type"], "text")
|
||||
self.assertEqual(task_def["impl"], CustomPipeline)
|
||||
|
||||
Reference in New Issue
Block a user