Rework pipeline tests (#19366)
* Rework pipeline tests * Try to fix Flax tests * Try to put it before * Use a new decorator instead * Remove ignore marker since it doesn't work * Filter pipeline tests * Woopsie * Use the fitlered list * Clean up and fake modif * Remove init * Revert fake modif
This commit is contained in:
@@ -133,7 +133,6 @@ _run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=Fa
|
||||
_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=False)
|
||||
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
|
||||
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
|
||||
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=False)
|
||||
_run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)
|
||||
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
|
||||
|
||||
@@ -176,25 +175,6 @@ def is_pt_flax_cross_test(test_case):
|
||||
return pytest.mark.is_pt_flax_cross_test()(test_case)
|
||||
|
||||
|
||||
def is_pipeline_test(test_case):
|
||||
"""
|
||||
Decorator marking a test as a pipeline test.
|
||||
|
||||
Pipeline tests are skipped by default and we can run only them by setting RUN_PIPELINE_TESTS environment variable
|
||||
to a truthy value and selecting the is_pipeline_test pytest mark.
|
||||
|
||||
"""
|
||||
if not _run_pipeline_tests:
|
||||
return unittest.skip("test is pipeline test")(test_case)
|
||||
else:
|
||||
try:
|
||||
import pytest # We don't need a hard dependency on pytest in the main library
|
||||
except ImportError:
|
||||
return test_case
|
||||
else:
|
||||
return pytest.mark.is_pipeline_test()(test_case)
|
||||
|
||||
|
||||
def is_staging_test(test_case):
|
||||
"""
|
||||
Decorator marking a test as a staging test.
|
||||
@@ -309,6 +289,18 @@ def require_torch(test_case):
|
||||
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
|
||||
|
||||
|
||||
def require_torch_or_tf(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch or TensorFlow.
|
||||
|
||||
These tests are skipped when neither PyTorch not TensorFlow is installed.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(is_torch_available() or is_tf_available(), "test requires PyTorch or TensorFlow")(
|
||||
test_case
|
||||
)
|
||||
|
||||
|
||||
def require_intel_extension_for_pytorch(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires Intel Extension for PyTorch.
|
||||
|
||||
Reference in New Issue
Block a user