Rework pipeline tests (#19366)

* Rework pipeline tests

* Try to fix Flax tests

* Try to put it before

* Use a new decorator instead

* Remove ignore marker since it doesn't work

* Filter pipeline tests

* Woopsie

* Use the fitlered list

* Clean up and fake modif

* Remove init

* Revert fake modif
This commit is contained in:
Sylvain Gugger
2022-10-07 18:01:58 -04:00
committed by GitHub
parent 983451a13e
commit 9ac586b3c8
27 changed files with 95 additions and 149 deletions

View File

@@ -48,13 +48,13 @@ from transformers.testing_utils import (
USER,
CaptureLogger,
RequestCounter,
is_pipeline_test,
is_staging_test,
nested_simplify,
require_scatter,
require_tensorflow_probability,
require_tf,
require_torch,
require_torch_or_tf,
slow,
)
from transformers.utils import is_tf_available, is_torch_available
@@ -307,7 +307,6 @@ class PipelineTestCaseMeta(type):
return type.__new__(mcs, name, bases, dct)
@is_pipeline_test
class CommonPipelineTest(unittest.TestCase):
@require_torch
def test_pipeline_iteration(self):
@@ -416,7 +415,6 @@ class CommonPipelineTest(unittest.TestCase):
self.assertEqual(len(outputs), 20)
@is_pipeline_test
class PipelinePadTest(unittest.TestCase):
@require_torch
def test_pipeline_padding(self):
@@ -498,7 +496,6 @@ class PipelinePadTest(unittest.TestCase):
)
@is_pipeline_test
class PipelineUtilsTest(unittest.TestCase):
@require_torch
def test_pipeline_dataset(self):
@@ -795,7 +792,6 @@ class CustomPipeline(Pipeline):
return model_outputs["logits"].softmax(-1).numpy()
@is_pipeline_test
class CustomPipelineTest(unittest.TestCase):
def test_warning_logs(self):
transformers_logging.set_verbosity_debug()
@@ -835,6 +831,7 @@ class CustomPipelineTest(unittest.TestCase):
# Clean registry for next tests.
del PIPELINE_REGISTRY.supported_tasks["custom-text-classification"]
@require_torch_or_tf
def test_dynamic_pipeline(self):
PIPELINE_REGISTRY.register_pipeline(
"pair-classification",
@@ -886,6 +883,7 @@ class CustomPipelineTest(unittest.TestCase):
[{"label": "LABEL_0", "score": 0.505}],
)
@require_torch_or_tf
def test_cached_pipeline_has_minimum_calls_to_head(self):
# Make sure we have cached the pipeline.
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")