feat: add pipeline registry abstraction (#17905)

* feat: add pipeline registry abstraction

- added `PipelineRegistry` abstraction
- updates `add_new_pipeline.mdx` (english docs) to reflect the api addition
- migrate `check_task` and `get_supported_tasks` from
  transformers/pipelines/__init__.py to
  transformers/pipelines/base.py#PipelineRegistry.{check_task,get_supported_tasks}

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: update with upstream/main

chore: Apply suggestions from sgugger's code review

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* chore: PR updates

- revert src/transformers/dependency_versions_table.py from upstream/main
- updates pipeline registry to use global variables

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* tests: add tests for pipeline registry

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* tests: add test for output warning.

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: fmt and cleanup unused imports

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: change imports to top of the file and address comments

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2022-06-30 12:11:08 -04:00
committed by GitHub
parent 9cb7cef285
commit 49cd736a28
4 changed files with 125 additions and 21 deletions

View File

@@ -28,6 +28,7 @@ from transformers import (
FEATURE_EXTRACTOR_MAPPING,
TOKENIZER_MAPPING,
AutoFeatureExtractor,
AutoModelForSequenceClassification,
AutoTokenizer,
DistilBertForSequenceClassification,
IBertConfig,
@@ -35,9 +36,10 @@ from transformers import (
TextClassificationPipeline,
pipeline,
)
from transformers.pipelines import get_task
from transformers.pipelines.base import _pad
from transformers.pipelines import PIPELINE_REGISTRY, get_task
from transformers.pipelines.base import Pipeline, _pad
from transformers.testing_utils import (
CaptureLogger,
is_pipeline_test,
nested_simplify,
require_scatter,
@@ -46,6 +48,7 @@ from transformers.testing_utils import (
require_torch,
slow,
)
from transformers.utils import logging as transformers_logging
logger = logging.getLogger(__name__)
@@ -746,3 +749,51 @@ class PipelineUtilsTest(unittest.TestCase):
models_are_equal = False
return models_are_equal
class CustomPipeline(Pipeline):
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
if "maybe_arg" in kwargs:
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
return preprocess_kwargs, {}, {}
def preprocess(self, text, maybe_arg=2):
input_ids = self.tokenizer(text, return_tensors="pt")
return input_ids
def _forward(self, model_inputs):
outputs = self.model(**model_inputs)
return outputs
def postprocess(self, model_outputs):
return model_outputs["logits"].softmax(-1).numpy()
@is_pipeline_test
class PipelineRegistryTest(unittest.TestCase):
def test_warning_logs(self):
transformers_logging.set_verbosity_debug()
logger_ = transformers_logging.get_logger("transformers.pipelines.base")
alias = "text-classification"
with CaptureLogger(logger_) as cm:
PIPELINE_REGISTRY.register_pipeline(alias, {})
self.assertIn(f"{alias} is already registered", cm.out)
@require_torch
def test_register_pipeline(self):
custom_text_classification = {
"impl": CustomPipeline,
"tf": (),
"pt": (AutoModelForSequenceClassification,),
"default": {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}},
"type": "text",
}
PIPELINE_REGISTRY.register_pipeline("custom-text-classification", custom_text_classification)
assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()
task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
self.assertEqual(task_def, custom_text_classification)
self.assertEqual(task_def["type"], "text")
self.assertEqual(task_def["impl"], CustomPipeline)