From 6f0723a9be8e2440f5d609e85da91baf14928e3c Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 1 Jul 2022 20:44:27 +0200 Subject: [PATCH] Restore original task in test_warning_logs (#17985) Co-authored-by: ydshieh --- tests/pipelines/test_pipelines_common.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index c75564e925..9c3a94c64c 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -777,9 +777,17 @@ class PipelineRegistryTest(unittest.TestCase): logger_ = transformers_logging.get_logger("transformers.pipelines.base") alias = "text-classification" - with CaptureLogger(logger_) as cm: - PIPELINE_REGISTRY.register_pipeline(alias, {}) - self.assertIn(f"{alias} is already registered", cm.out) + # Get the original task, so we can restore it at the end. + # (otherwise the subsequential tests in `TextClassificationPipelineTests` will fail) + original_task, original_task_options = PIPELINE_REGISTRY.check_task(alias) + + try: + with CaptureLogger(logger_) as cm: + PIPELINE_REGISTRY.register_pipeline(alias, {}) + self.assertIn(f"{alias} is already registered", cm.out) + finally: + # restore + PIPELINE_REGISTRY.register_pipeline(alias, original_task) @require_torch def test_register_pipeline(self):