From f91810da88d298030518a592d921e2e3725fe117 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 13 Jun 2023 14:28:08 +0100 Subject: [PATCH] Safely import pytest in testing_utils.py (#24241) --- src/transformers/testing_utils.py | 31 ++++++++++++++++---------- src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 5 +++++ 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index f703575190..ce2dec08e8 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -36,18 +36,6 @@ from unittest import mock import huggingface_hub import requests -from _pytest.doctest import ( - Module, - _get_checker, - _get_continue_on_failure, - _get_runner, - _is_mocked, - _patch_unwrap_mock_aware, - get_optionflags, - import_path, -) -from _pytest.outcomes import skip -from pytest import DoctestItem from transformers import logging as transformers_logging @@ -83,6 +71,7 @@ from .utils import ( is_phonemizer_available, is_pyctcdecode_available, is_pytesseract_available, + is_pytest_available, is_pytorch_quantization_available, is_rjieba_available, is_safetensors_available, @@ -116,6 +105,24 @@ if is_accelerate_available(): from accelerate.state import AcceleratorState, PartialState +if is_pytest_available(): + from _pytest.doctest import ( + Module, + _get_checker, + _get_continue_on_failure, + _get_runner, + _is_mocked, + _patch_unwrap_mock_aware, + get_optionflags, + import_path, + ) + from _pytest.outcomes import skip + from pytest import DoctestItem +else: + Module = object + DoctestItem = object + + SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown" DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer" diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 3aa1f8aeb9..9fc9d7ee30 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -133,6 +133,7 @@ from .import_utils import ( is_py3nvml_available, is_pyctcdecode_available, is_pytesseract_available, + is_pytest_available, is_pytorch_quantization_available, is_rjieba_available, is_sacremoses_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 93e5e74bba..eb9cdcc67b 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -105,6 +105,7 @@ _psutil_available = _is_package_available("psutil") _py3nvml_available = _is_package_available("py3nvml") _pyctcdecode_available = _is_package_available("pyctcdecode") _pytesseract_available = _is_package_available("pytesseract") +_pytest_available = _is_package_available("pytest") _pytorch_quantization_available = _is_package_available("pytorch_quantization") _rjieba_available = _is_package_available("rjieba") _sacremoses_available = _is_package_available("sacremoses") @@ -547,6 +548,10 @@ def is_pytesseract_available(): return _pytesseract_available +def is_pytest_available(): + return _pytest_available + + def is_spacy_available(): return _spacy_available