From 49cd736a288a315d741e5c337790effa4c9fa689 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Thu, 30 Jun 2022 12:11:08 -0400 Subject: [PATCH] feat: add pipeline registry abstraction (#17905) * feat: add pipeline registry abstraction - added `PipelineRegistry` abstraction - updates `add_new_pipeline.mdx` (english docs) to reflect the api addition - migrate `check_task` and `get_supported_tasks` from transformers/pipelines/__init__.py to transformers/pipelines/base.py#PipelineRegistry.{check_task,get_supported_tasks} Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> * fix: update with upstream/main chore: Apply suggestions from sgugger's code review Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * chore: PR updates - revert src/transformers/dependency_versions_table.py from upstream/main - updates pipeline registry to use global variables Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> * tests: add tests for pipeline registry Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> * tests: add test for output warning. Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> * chore: fmt and cleanup unused imports Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> * fix: change imports to top of the file and address comments Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/en/add_new_pipeline.mdx | 31 ++++++++++++- src/transformers/pipelines/__init__.py | 22 +++------- src/transformers/pipelines/base.py | 38 ++++++++++++++++ tests/pipelines/test_pipelines_common.py | 55 +++++++++++++++++++++++- 4 files changed, 125 insertions(+), 21 deletions(-) diff --git a/docs/source/en/add_new_pipeline.mdx b/docs/source/en/add_new_pipeline.mdx index 096ea423ec..1b07e651e6 100644 --- a/docs/source/en/add_new_pipeline.mdx +++ b/docs/source/en/add_new_pipeline.mdx @@ -111,8 +111,35 @@ of arguments for ease of use (audio files, can be filenames, URLs or pure bytes) ## Adding it to the list of supported tasks -Go to `src/transformers/pipelines/__init__.py` and fill in `SUPPORTED_TASKS` with your newly created pipeline. -If possible it should provide a default model. +To register your `new-task` to the list of supported tasks, provide the +following task template: + +```python +my_new_task = { + "impl": MyPipeline, + "tf": (), + "pt": (AutoModelForAudioClassification,) if is_torch_available() else (), + "default": {"model": {"pt": "user/awesome_model"}}, + "type": "audio", # current support type: text, audio, image, multimodal +} +``` + + + +Take a look at the `src/transformers/pipelines/__init__.py` and the dictionary `SUPPORTED_TASKS` to see how a task is defined. +If possible your custom task should provide a default model. + + + +Then add your custom task to the list of supported tasks via +`PIPELINE_REGISTRY.register_pipeline()`: + +```python +from transformers.pipelines import PIPELINE_REGISTRY + +PIPELINE_REGISTRY.register_pipeline("new-task", my_new_task) +``` + ## Adding tests diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index ab373dd38b..e0c754a85d 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -41,6 +41,7 @@ from .base import ( Pipeline, PipelineDataFormat, PipelineException, + PipelineRegistry, get_default_model_and_revision, infer_framework_load_model, ) @@ -309,14 +310,14 @@ for task, values in SUPPORTED_TASKS.items(): elif values["type"] != "multimodal": raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}") +PIPELINE_REGISTRY = PipelineRegistry(supported_tasks=SUPPORTED_TASKS, task_aliases=TASK_ALIASES) + def get_supported_tasks() -> List[str]: """ Returns a list of supported task strings. """ - supported_tasks = list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()) - supported_tasks.sort() - return supported_tasks + return PIPELINE_REGISTRY.get_supported_tasks() def get_task(model: str, use_auth_token: Optional[str] = None) -> str: @@ -375,20 +376,7 @@ def check_task(task: str) -> Tuple[Dict, Any]: """ - if task in TASK_ALIASES: - task = TASK_ALIASES[task] - if task in SUPPORTED_TASKS: - targeted_task = SUPPORTED_TASKS[task] - return targeted_task, None - - if task.startswith("translation"): - tokens = task.split("_") - if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to": - targeted_task = SUPPORTED_TASKS["translation"] - return targeted_task, (tokens[1], tokens[3]) - raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format") - - raise KeyError(f"Unknown task {task}, available tasks are {get_supported_tasks() + ['translation_XX_to_YY']}") + return PIPELINE_REGISTRY.check_task(task) def pipeline( diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index fd49e2dea0..0e2b9ac2b8 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -1087,3 +1087,41 @@ class ChunkPipeline(Pipeline): model_iterator = PipelinePackIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size) final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params) return final_iterator + + +class PipelineRegistry: + def __init__(self, supported_tasks: Dict[str, Any], task_aliases: Dict[str, str]) -> None: + self.supported_tasks = supported_tasks + self.task_aliases = task_aliases + + def get_supported_tasks(self) -> List[str]: + supported_task = list(self.supported_tasks.keys()) + list(self.task_aliases.keys()) + supported_task.sort() + return supported_task + + def check_task(self, task: str) -> Tuple[Dict, Any]: + if task in self.task_aliases: + task = self.task_aliases[task] + if task in self.supported_tasks: + targeted_task = self.supported_tasks[task] + return targeted_task, None + + if task.startswith("translation"): + tokens = task.split("_") + if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to": + targeted_task = self.supported_tasks["translation"] + return targeted_task, (tokens[1], tokens[3]) + raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format") + + raise KeyError( + f"Unknown task {task}, available tasks are {self.get_supported_tasks() + ['translation_XX_to_YY']}" + ) + + def register_pipeline(self, task: str, task_impl: Dict[str, Any]) -> None: + if task in self.supported_tasks: + logger.warning(f"{task} is already registered. Overwriting pipeline for task {task}...") + + self.supported_tasks[task] = task_impl + + def to_dict(self): + return self.supported_tasks diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 2d0a60805a..c75564e925 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -28,6 +28,7 @@ from transformers import ( FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, + AutoModelForSequenceClassification, AutoTokenizer, DistilBertForSequenceClassification, IBertConfig, @@ -35,9 +36,10 @@ from transformers import ( TextClassificationPipeline, pipeline, ) -from transformers.pipelines import get_task -from transformers.pipelines.base import _pad +from transformers.pipelines import PIPELINE_REGISTRY, get_task +from transformers.pipelines.base import Pipeline, _pad from transformers.testing_utils import ( + CaptureLogger, is_pipeline_test, nested_simplify, require_scatter, @@ -46,6 +48,7 @@ from transformers.testing_utils import ( require_torch, slow, ) +from transformers.utils import logging as transformers_logging logger = logging.getLogger(__name__) @@ -746,3 +749,51 @@ class PipelineUtilsTest(unittest.TestCase): models_are_equal = False return models_are_equal + + +class CustomPipeline(Pipeline): + def _sanitize_parameters(self, **kwargs): + preprocess_kwargs = {} + if "maybe_arg" in kwargs: + preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"] + return preprocess_kwargs, {}, {} + + def preprocess(self, text, maybe_arg=2): + input_ids = self.tokenizer(text, return_tensors="pt") + return input_ids + + def _forward(self, model_inputs): + outputs = self.model(**model_inputs) + return outputs + + def postprocess(self, model_outputs): + return model_outputs["logits"].softmax(-1).numpy() + + +@is_pipeline_test +class PipelineRegistryTest(unittest.TestCase): + def test_warning_logs(self): + transformers_logging.set_verbosity_debug() + logger_ = transformers_logging.get_logger("transformers.pipelines.base") + + alias = "text-classification" + with CaptureLogger(logger_) as cm: + PIPELINE_REGISTRY.register_pipeline(alias, {}) + self.assertIn(f"{alias} is already registered", cm.out) + + @require_torch + def test_register_pipeline(self): + custom_text_classification = { + "impl": CustomPipeline, + "tf": (), + "pt": (AutoModelForSequenceClassification,), + "default": {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}}, + "type": "text", + } + PIPELINE_REGISTRY.register_pipeline("custom-text-classification", custom_text_classification) + assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks() + + task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification") + self.assertEqual(task_def, custom_text_classification) + self.assertEqual(task_def["type"], "text") + self.assertEqual(task_def["impl"], CustomPipeline)