Make quality scripts work when one backend is missing. (#11573)
* Make quality scripts work when one backend is missing. * Check env variable is properly set * Add default * With print statements * Fix typo * Set env variable * Remove debug code
This commit is contained in:
@@ -391,6 +391,8 @@ jobs:
|
|||||||
docker:
|
docker:
|
||||||
- image: circleci/python:3.6
|
- image: circleci/python:3.6
|
||||||
resource_class: medium
|
resource_class: medium
|
||||||
|
environment:
|
||||||
|
TRANSFORMERS_IS_CI: yes
|
||||||
parallelism: 1
|
parallelism: 1
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
|
|||||||
@@ -17,8 +17,11 @@ import importlib
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from transformers import is_flax_available, is_tf_available, is_torch_available
|
||||||
|
from transformers.file_utils import ENV_VARS_TRUE_VALUES
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
|
|
||||||
|
|
||||||
@@ -250,15 +253,18 @@ def check_all_models_are_tested():
|
|||||||
def get_all_auto_configured_models():
|
def get_all_auto_configured_models():
|
||||||
"""Return the list of all models in at least one auto class."""
|
"""Return the list of all models in at least one auto class."""
|
||||||
result = set() # To avoid duplicates we concatenate all model classes in a set.
|
result = set() # To avoid duplicates we concatenate all model classes in a set.
|
||||||
for attr_name in dir(transformers.models.auto.modeling_auto):
|
if is_torch_available():
|
||||||
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
|
for attr_name in dir(transformers.models.auto.modeling_auto):
|
||||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
|
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
|
||||||
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
|
result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
|
||||||
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
|
if is_tf_available():
|
||||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
|
for attr_name in dir(transformers.models.auto.modeling_tf_auto):
|
||||||
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
|
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
|
||||||
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
|
result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
|
||||||
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
|
if is_flax_available():
|
||||||
|
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
|
||||||
|
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
|
||||||
|
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
|
||||||
return [cls.__name__ for cls in result]
|
return [cls.__name__ for cls in result]
|
||||||
|
|
||||||
|
|
||||||
@@ -289,6 +295,27 @@ def check_models_are_auto_configured(module, all_auto_models):
|
|||||||
|
|
||||||
def check_all_models_are_auto_configured():
|
def check_all_models_are_auto_configured():
|
||||||
"""Check all models are each in an auto class."""
|
"""Check all models are each in an auto class."""
|
||||||
|
missing_backends = []
|
||||||
|
if not is_torch_available():
|
||||||
|
missing_backends.append("PyTorch")
|
||||||
|
if not is_tf_available():
|
||||||
|
missing_backends.append("TensorFlow")
|
||||||
|
if not is_flax_available():
|
||||||
|
missing_backends.append("Flax")
|
||||||
|
if len(missing_backends) > 0:
|
||||||
|
missing = ", ".join(missing_backends)
|
||||||
|
if os.getenv("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
|
||||||
|
raise Exception(
|
||||||
|
"Full quality checks require all backends to be installed (with `pip install -e .[dev]` in the "
|
||||||
|
f"Transformers repo, the following are missing: {missing}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
"Full quality checks require all backends to be installed (with `pip install -e .[dev]` in the "
|
||||||
|
f"Transformers repo, the following are missing: {missing}. While it's probably fine as long as you "
|
||||||
|
"didn't make any change in one of those backends modeling files, you should probably execute the "
|
||||||
|
"command above to be on the safe side."
|
||||||
|
)
|
||||||
modules = get_model_modules()
|
modules = get_model_modules()
|
||||||
all_auto_models = get_all_auto_configured_models()
|
all_auto_models = get_all_auto_configured_models()
|
||||||
failures = []
|
failures = []
|
||||||
|
|||||||
Reference in New Issue
Block a user