Restore original task in test_warning_logs (#17985)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -777,9 +777,17 @@ class PipelineRegistryTest(unittest.TestCase):
|
|||||||
logger_ = transformers_logging.get_logger("transformers.pipelines.base")
|
logger_ = transformers_logging.get_logger("transformers.pipelines.base")
|
||||||
|
|
||||||
alias = "text-classification"
|
alias = "text-classification"
|
||||||
|
# Get the original task, so we can restore it at the end.
|
||||||
|
# (otherwise the subsequential tests in `TextClassificationPipelineTests` will fail)
|
||||||
|
original_task, original_task_options = PIPELINE_REGISTRY.check_task(alias)
|
||||||
|
|
||||||
|
try:
|
||||||
with CaptureLogger(logger_) as cm:
|
with CaptureLogger(logger_) as cm:
|
||||||
PIPELINE_REGISTRY.register_pipeline(alias, {})
|
PIPELINE_REGISTRY.register_pipeline(alias, {})
|
||||||
self.assertIn(f"{alias} is already registered", cm.out)
|
self.assertIn(f"{alias} is already registered", cm.out)
|
||||||
|
finally:
|
||||||
|
# restore
|
||||||
|
PIPELINE_REGISTRY.register_pipeline(alias, original_task)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_register_pipeline(self):
|
def test_register_pipeline(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user