diff --git a/src/transformers/onnx/utils.py b/src/transformers/onnx/utils.py index 83b18c6ab5..9672b0a96a 100644 --- a/src/transformers/onnx/utils.py +++ b/src/transformers/onnx/utils.py @@ -14,9 +14,11 @@ from ctypes import c_float, sizeof from enum import Enum -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union -from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer + +if TYPE_CHECKING: + from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer # tests_ignore class ParameterFormat(Enum): @@ -66,7 +68,7 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm return num_parameters * dtype.size -def get_preprocessor(model_name: str) -> Optional[Union[AutoTokenizer, AutoFeatureExtractor, AutoProcessor]]: +def get_preprocessor(model_name: str) -> Optional[Union["AutoTokenizer", "AutoFeatureExtractor", "AutoProcessor"]]: """ Gets a preprocessor (tokenizer, feature extractor or processor) that is available for `model_name`. @@ -79,6 +81,9 @@ def get_preprocessor(model_name: str) -> Optional[Union[AutoTokenizer, AutoFeatu returned. If both a tokenizer and a feature extractor exist, an error is raised. The function returns `None` if no preprocessor is found. """ + # Avoid circular imports by only importing this here. + from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer # tests_ignore + try: return AutoProcessor.from_pretrained(model_name) except (ValueError, OSError, KeyError):