Flax testing should not run the full torch test suite (#10725)

* make flax tests pytorch independent

* fix typo

* finish

* improve circle ci

* fix return tensors

* correct flax test

* re-add sentencepiece

* last tokenizer fixes

* finish maybe now
This commit is contained in:
Patrick von Platen
2021-03-16 08:05:37 +03:00
committed by GitHub
parent 87d685b8a9
commit 9f8619c6aa
9 changed files with 94 additions and 14 deletions

View File

@@ -80,6 +80,7 @@ def parse_int_from_env(key, default=None):
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=False)
_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=False)
_run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)
@@ -105,6 +106,25 @@ def is_pt_tf_cross_test(test_case):
return pytest.mark.is_pt_tf_cross_test()(test_case)
def is_pt_flax_cross_test(test_case):
"""
Decorator marking a test as a test that control interactions between PyTorch and Flax
PT+FLAX tests are skipped by default and we can run only them by setting RUN_PT_FLAX_CROSS_TESTS environment
variable to a truthy value and selecting the is_pt_flax_cross_test pytest mark.
"""
if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
return unittest.skip("test is PT+FLAX test")(test_case)
else:
try:
import pytest # We don't need a hard dependency on pytest in the main library
except ImportError:
return test_case
else:
return pytest.mark.is_pt_flax_cross_test()(test_case)
def is_pipeline_test(test_case):
"""
Decorator marking a test as a pipeline test.