Fix pipeline tests (#18487)
* Fix pipeline tests * Make sure all pipelines tests run with init changes
This commit is contained in:
@@ -606,6 +606,7 @@ def pipeline(
|
|||||||
|
|
||||||
# Retrieve the task
|
# Retrieve the task
|
||||||
if task in custom_tasks:
|
if task in custom_tasks:
|
||||||
|
normalized_task = task
|
||||||
targeted_task, task_options = clean_custom_task(custom_tasks[task])
|
targeted_task, task_options = clean_custom_task(custom_tasks[task])
|
||||||
if pipeline_class is None:
|
if pipeline_class is None:
|
||||||
if not trust_remote_code:
|
if not trust_remote_code:
|
||||||
|
|||||||
@@ -795,7 +795,7 @@ class CustomPipelineTest(unittest.TestCase):
|
|||||||
alias = "text-classification"
|
alias = "text-classification"
|
||||||
# Get the original task, so we can restore it at the end.
|
# Get the original task, so we can restore it at the end.
|
||||||
# (otherwise the subsequential tests in `TextClassificationPipelineTests` will fail)
|
# (otherwise the subsequential tests in `TextClassificationPipelineTests` will fail)
|
||||||
original_task, original_task_options = PIPELINE_REGISTRY.check_task(alias)
|
_, original_task, _ = PIPELINE_REGISTRY.check_task(alias)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with CaptureLogger(logger_) as cm:
|
with CaptureLogger(logger_) as cm:
|
||||||
@@ -816,7 +816,7 @@ class CustomPipelineTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()
|
assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()
|
||||||
|
|
||||||
task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
|
_, task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
|
||||||
self.assertEqual(task_def["pt"], (AutoModelForSequenceClassification,) if is_torch_available() else ())
|
self.assertEqual(task_def["pt"], (AutoModelForSequenceClassification,) if is_torch_available() else ())
|
||||||
self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ())
|
self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ())
|
||||||
self.assertEqual(task_def["type"], "text")
|
self.assertEqual(task_def["type"], "text")
|
||||||
|
|||||||
@@ -377,6 +377,7 @@ SPECIAL_MODULE_TO_TEST_MAP = {
|
|||||||
],
|
],
|
||||||
"optimization.py": "optimization/test_optimization.py",
|
"optimization.py": "optimization/test_optimization.py",
|
||||||
"optimization_tf.py": "optimization/test_optimization_tf.py",
|
"optimization_tf.py": "optimization/test_optimization_tf.py",
|
||||||
|
"pipelines/__init__.py": "pipelines/test_pipelines_*.py",
|
||||||
"pipelines/base.py": "pipelines/test_pipelines_*.py",
|
"pipelines/base.py": "pipelines/test_pipelines_*.py",
|
||||||
"pipelines/text2text_generation.py": [
|
"pipelines/text2text_generation.py": [
|
||||||
"pipelines/test_pipelines_text2text_generation.py",
|
"pipelines/test_pipelines_text2text_generation.py",
|
||||||
|
|||||||
Reference in New Issue
Block a user