Rework pipeline tests (#19366)
* Rework pipeline tests * Try to fix Flax tests * Try to put it before * Use a new decorator instead * Remove ignore marker since it doesn't work * Filter pipeline tests * Woopsie * Use the fitlered list * Clean up and fake modif * Remove init * Revert fake modif
This commit is contained in:
@@ -48,13 +48,13 @@ from transformers.testing_utils import (
|
||||
USER,
|
||||
CaptureLogger,
|
||||
RequestCounter,
|
||||
is_pipeline_test,
|
||||
is_staging_test,
|
||||
nested_simplify,
|
||||
require_scatter,
|
||||
require_tensorflow_probability,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_or_tf,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import is_tf_available, is_torch_available
|
||||
@@ -307,7 +307,6 @@ class PipelineTestCaseMeta(type):
|
||||
return type.__new__(mcs, name, bases, dct)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class CommonPipelineTest(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_pipeline_iteration(self):
|
||||
@@ -416,7 +415,6 @@ class CommonPipelineTest(unittest.TestCase):
|
||||
self.assertEqual(len(outputs), 20)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class PipelinePadTest(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_pipeline_padding(self):
|
||||
@@ -498,7 +496,6 @@ class PipelinePadTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class PipelineUtilsTest(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_pipeline_dataset(self):
|
||||
@@ -795,7 +792,6 @@ class CustomPipeline(Pipeline):
|
||||
return model_outputs["logits"].softmax(-1).numpy()
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class CustomPipelineTest(unittest.TestCase):
|
||||
def test_warning_logs(self):
|
||||
transformers_logging.set_verbosity_debug()
|
||||
@@ -835,6 +831,7 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
# Clean registry for next tests.
|
||||
del PIPELINE_REGISTRY.supported_tasks["custom-text-classification"]
|
||||
|
||||
@require_torch_or_tf
|
||||
def test_dynamic_pipeline(self):
|
||||
PIPELINE_REGISTRY.register_pipeline(
|
||||
"pair-classification",
|
||||
@@ -886,6 +883,7 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
[{"label": "LABEL_0", "score": 0.505}],
|
||||
)
|
||||
|
||||
@require_torch_or_tf
|
||||
def test_cached_pipeline_has_minimum_calls_to_head(self):
|
||||
# Make sure we have cached the pipeline.
|
||||
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
||||
|
||||
Reference in New Issue
Block a user