feat: add pipeline registry abstraction (#17905)
* feat: add pipeline registry abstraction
- added `PipelineRegistry` abstraction
- updates `add_new_pipeline.mdx` (english docs) to reflect the api addition
- migrate `check_task` and `get_supported_tasks` from
transformers/pipelines/__init__.py to
transformers/pipelines/base.py#PipelineRegistry.{check_task,get_supported_tasks}
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* fix: update with upstream/main
chore: Apply suggestions from sgugger's code review
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* chore: PR updates
- revert src/transformers/dependency_versions_table.py from upstream/main
- updates pipeline registry to use global variables
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* tests: add tests for pipeline registry
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* tests: add test for output warning.
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* chore: fmt and cleanup unused imports
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
* fix: change imports to top of the file and address comments
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -28,6 +28,7 @@ from transformers import (
|
||||
FEATURE_EXTRACTOR_MAPPING,
|
||||
TOKENIZER_MAPPING,
|
||||
AutoFeatureExtractor,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
DistilBertForSequenceClassification,
|
||||
IBertConfig,
|
||||
@@ -35,9 +36,10 @@ from transformers import (
|
||||
TextClassificationPipeline,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.pipelines import get_task
|
||||
from transformers.pipelines.base import _pad
|
||||
from transformers.pipelines import PIPELINE_REGISTRY, get_task
|
||||
from transformers.pipelines.base import Pipeline, _pad
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_pipeline_test,
|
||||
nested_simplify,
|
||||
require_scatter,
|
||||
@@ -46,6 +48,7 @@ from transformers.testing_utils import (
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import logging as transformers_logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -746,3 +749,51 @@ class PipelineUtilsTest(unittest.TestCase):
|
||||
models_are_equal = False
|
||||
|
||||
return models_are_equal
|
||||
|
||||
|
||||
class CustomPipeline(Pipeline):
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
if "maybe_arg" in kwargs:
|
||||
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
|
||||
return preprocess_kwargs, {}, {}
|
||||
|
||||
def preprocess(self, text, maybe_arg=2):
|
||||
input_ids = self.tokenizer(text, return_tensors="pt")
|
||||
return input_ids
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
outputs = self.model(**model_inputs)
|
||||
return outputs
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
return model_outputs["logits"].softmax(-1).numpy()
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class PipelineRegistryTest(unittest.TestCase):
|
||||
def test_warning_logs(self):
|
||||
transformers_logging.set_verbosity_debug()
|
||||
logger_ = transformers_logging.get_logger("transformers.pipelines.base")
|
||||
|
||||
alias = "text-classification"
|
||||
with CaptureLogger(logger_) as cm:
|
||||
PIPELINE_REGISTRY.register_pipeline(alias, {})
|
||||
self.assertIn(f"{alias} is already registered", cm.out)
|
||||
|
||||
@require_torch
|
||||
def test_register_pipeline(self):
|
||||
custom_text_classification = {
|
||||
"impl": CustomPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForSequenceClassification,),
|
||||
"default": {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}},
|
||||
"type": "text",
|
||||
}
|
||||
PIPELINE_REGISTRY.register_pipeline("custom-text-classification", custom_text_classification)
|
||||
assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()
|
||||
|
||||
task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
|
||||
self.assertEqual(task_def, custom_text_classification)
|
||||
self.assertEqual(task_def["type"], "text")
|
||||
self.assertEqual(task_def["impl"], CustomPipeline)
|
||||
|
||||
Reference in New Issue
Block a user