Custom pipeline (#18079)
* Initial work * More work * Add tests for custom pipelines on the Hub * Protect import * Make the test work for TF as well * Last PyTorch specific bit * Add documentation * Style * Title in toc * Bad names! * Update docs/source/en/add_new_pipeline.mdx Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr> * Auto stash before merge of "custom_pipeline" and "origin/custom_pipeline" * Address review comments * Address more review comments * Update src/transformers/pipelines/__init__.py Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr> Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -15,15 +15,21 @@
|
||||
import copy
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from abc import abstractmethod
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from unittest import skipIf
|
||||
|
||||
import numpy as np
|
||||
|
||||
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import (
|
||||
FEATURE_EXTRACTOR_MAPPING,
|
||||
TOKENIZER_MAPPING,
|
||||
@@ -34,13 +40,17 @@ from transformers import (
|
||||
IBertConfig,
|
||||
RobertaConfig,
|
||||
TextClassificationPipeline,
|
||||
TFAutoModelForSequenceClassification,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.pipelines import PIPELINE_REGISTRY, get_task
|
||||
from transformers.pipelines.base import Pipeline, _pad
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
is_pipeline_test,
|
||||
is_staging_test,
|
||||
nested_simplify,
|
||||
require_scatter,
|
||||
require_tensorflow_probability,
|
||||
@@ -48,9 +58,15 @@ from transformers.testing_utils import (
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import is_tf_available, is_torch_available
|
||||
from transformers.utils import logging as transformers_logging
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_pipeline import PairClassificationPipeline # noqa E402
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -771,7 +787,7 @@ class CustomPipeline(Pipeline):
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class PipelineRegistryTest(unittest.TestCase):
|
||||
class CustomPipelineTest(unittest.TestCase):
|
||||
def test_warning_logs(self):
|
||||
transformers_logging.set_verbosity_debug()
|
||||
logger_ = transformers_logging.get_logger("transformers.pipelines.base")
|
||||
@@ -783,25 +799,165 @@ class PipelineRegistryTest(unittest.TestCase):
|
||||
|
||||
try:
|
||||
with CaptureLogger(logger_) as cm:
|
||||
PIPELINE_REGISTRY.register_pipeline(alias, {})
|
||||
PIPELINE_REGISTRY.register_pipeline(alias, PairClassificationPipeline)
|
||||
self.assertIn(f"{alias} is already registered", cm.out)
|
||||
finally:
|
||||
# restore
|
||||
PIPELINE_REGISTRY.register_pipeline(alias, original_task)
|
||||
PIPELINE_REGISTRY.supported_tasks[alias] = original_task
|
||||
|
||||
@require_torch
def test_register_pipeline(self):
    """Register a custom task through keyword arguments and check the stored task definition.

    The task is registered under "custom-text-classification" with
    `PairClassificationPipeline` as implementation, then looked up again via
    `PIPELINE_REGISTRY.check_task` and compared field by field. The entry is
    removed from the registry afterwards, whether or not the assertions pass.
    """
    task = "custom-text-classification"
    PIPELINE_REGISTRY.register_pipeline(
        task,
        pipeline_class=PairClassificationPipeline,
        pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
        tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
        default={"pt": "hf-internal-testing/tiny-random-distilbert"},
        type="text",
    )
    try:
        self.assertIn(task, PIPELINE_REGISTRY.get_supported_tasks())

        task_def, _ = PIPELINE_REGISTRY.check_task(task)
        self.assertEqual(task_def["pt"], (AutoModelForSequenceClassification,) if is_torch_available() else ())
        self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ())
        self.assertEqual(task_def["type"], "text")
        self.assertEqual(task_def["impl"], PairClassificationPipeline)
        self.assertEqual(task_def["default"], {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}})
    finally:
        # Clean registry for next tests, even when one of the assertions above fails.
        del PIPELINE_REGISTRY.supported_tasks[task]
|
||||
|
||||
def test_dynamic_pipeline(self):
    """Save a pipeline built from a registered custom class and reload it from disk.

    Checks that:
    - saving records the custom pipeline under `model.config.custom_pipelines`;
    - reloading without `trust_remote_code=True` raises `ValueError`;
    - reloading with `trust_remote_code=True` yields a `PairClassificationPipeline`;
    - passing an explicit "text-classification" task with `trust_remote_code=False`
      yields a regular `TextClassificationPipeline`.
    """
    PIPELINE_REGISTRY.register_pipeline(
        "pair-classification",
        pipeline_class=PairClassificationPipeline,
        pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
        tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
    )
    try:
        classifier = pipeline("pair-classification", model="hf-internal-testing/tiny-random-bert")
    finally:
        # Clean registry as we won't need the pipeline to be in it for the rest to work. Doing it in
        # `finally` keeps the shared registry clean even when building the pipeline fails.
        del PIPELINE_REGISTRY.supported_tasks["pair-classification"]

    with tempfile.TemporaryDirectory() as tmp_dir:
        classifier.save_pretrained(tmp_dir)
        # checks
        self.assertDictEqual(
            classifier.model.config.custom_pipelines,
            {
                "pair-classification": {
                    "impl": "custom_pipeline.PairClassificationPipeline",
                    "pt": ("AutoModelForSequenceClassification",) if is_torch_available() else (),
                    "tf": ("TFAutoModelForSequenceClassification",) if is_tf_available() else (),
                }
            },
        )
        # Fails if the user forgets to pass along `trust_remote_code=True`
        with self.assertRaises(ValueError):
            _ = pipeline(model=tmp_dir)

        new_classifier = pipeline(model=tmp_dir, trust_remote_code=True)
        # Using trust_remote_code=False forces the traditional pipeline tag
        old_classifier = pipeline("text-classification", model=tmp_dir, trust_remote_code=False)
        # Can't make an isinstance check because the new_classifier is from the PairClassificationPipeline class of a
        # dynamic module
        self.assertEqual(new_classifier.__class__.__name__, "PairClassificationPipeline")
        self.assertEqual(new_classifier.task, "pair-classification")
        results = new_classifier("I hate you", second_text="I love you")
        self.assertDictEqual(
            nested_simplify(results),
            {"label": "LABEL_0", "score": 0.505, "logits": [-0.003, -0.024]},
        )

        self.assertEqual(old_classifier.__class__.__name__, "TextClassificationPipeline")
        self.assertEqual(old_classifier.task, "text-classification")
        results = old_classifier("I hate you", text_pair="I love you")
        self.assertListEqual(
            nested_simplify(results),
            [{"label": "LABEL_0", "score": 0.505}],
        )
|
||||
|
||||
|
||||
@require_torch
@is_staging_test
class DynamicPipelineTester(unittest.TestCase):
    """Staging test: push a pipeline built from a custom class to the Hub and load it back."""

    # Written one token per line to `vocab.txt` to build the tokenizer used by the test.
    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "I", "love", "hate", "you"]

    @classmethod
    def setUpClass(cls):
        """Store the staging token on the class and register it with `huggingface_hub`."""
        cls._token = TOKEN
        set_access_token(TOKEN)
        HfFolder.save_token(TOKEN)

    @classmethod
    def tearDownClass(cls):
        """Delete the test repo; an `HTTPError` from the deletion is deliberately ignored."""
        try:
            delete_repo(token=cls._token, repo_id="test-dynamic-pipeline")
        except HTTPError:
            # Best-effort cleanup: a failed deletion must not fail the test class.
            pass

    def test_push_to_hub_dynamic_pipeline(self):
        """Push a custom pipeline to `{USER}/test-dynamic-pipeline` and reload it.

        Checks that the reloaded pipeline needs `trust_remote_code=True`, has the custom
        class name, and returns the same results as the local one. Also checks that an
        explicit "text-classification" task with `trust_remote_code=False` gives a
        `TextClassificationPipeline` whose label and score match.
        """
        from transformers import BertConfig, BertForSequenceClassification, BertTokenizer

        config = BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )
        model = BertForSequenceClassification(config).eval()

        with tempfile.TemporaryDirectory() as tmp_dir:
            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-pipeline", use_auth_token=self._token)

            vocab_file = os.path.join(tmp_dir, "vocab.txt")
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
            tokenizer = BertTokenizer(vocab_file)

            # Register right before building the pipeline and always unregister afterwards, so a
            # failure in the Hub clone, the tokenizer setup or `pipeline(...)` cannot leave the
            # task in the shared registry.
            PIPELINE_REGISTRY.register_pipeline(
                "pair-classification",
                pipeline_class=PairClassificationPipeline,
                pt_model=AutoModelForSequenceClassification,
            )
            try:
                classifier = pipeline("pair-classification", model=model, tokenizer=tokenizer)
            finally:
                # Clean registry as we won't need the pipeline to be in it for the rest to work.
                del PIPELINE_REGISTRY.supported_tasks["pair-classification"]

            classifier.save_pretrained(tmp_dir)
            # checks
            self.assertDictEqual(
                classifier.model.config.custom_pipelines,
                {
                    "pair-classification": {
                        "impl": "custom_pipeline.PairClassificationPipeline",
                        "pt": ("AutoModelForSequenceClassification",),
                        "tf": (),
                    }
                },
            )

            repo.push_to_hub()

        # Fails if the user forgets to pass along `trust_remote_code=True`
        with self.assertRaises(ValueError):
            _ = pipeline(model=f"{USER}/test-dynamic-pipeline")

        new_classifier = pipeline(model=f"{USER}/test-dynamic-pipeline", trust_remote_code=True)
        # Can't make an isinstance check because the new_classifier is from the PairClassificationPipeline class of a
        # dynamic module
        self.assertEqual(new_classifier.__class__.__name__, "PairClassificationPipeline")

        results = classifier("I hate you", second_text="I love you")
        new_results = new_classifier("I hate you", second_text="I love you")
        self.assertDictEqual(nested_simplify(results), nested_simplify(new_results))

        # Using trust_remote_code=False forces the traditional pipeline tag
        old_classifier = pipeline(
            "text-classification", model=f"{USER}/test-dynamic-pipeline", trust_remote_code=False
        )
        self.assertEqual(old_classifier.__class__.__name__, "TextClassificationPipeline")
        self.assertEqual(old_classifier.task, "text-classification")
        new_results = old_classifier("I hate you", text_pair="I love you")
        self.assertListEqual(
            nested_simplify([{"label": results["label"], "score": results["score"]}]), nested_simplify(new_results)
        )
|
||||
|
||||
Reference in New Issue
Block a user