Fix circular import in onnx.utils (#17577)

* Fix circular import in onnx.utils

* Add comment for test fetcher

* Here too

* Style
This commit is contained in:
Sylvain Gugger
2022-06-07 08:00:36 -04:00
committed by GitHub
parent 9aa230aa2f
commit b6a65ae52a

View File

@@ -14,9 +14,11 @@
from ctypes import c_float, sizeof
from enum import Enum
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
if TYPE_CHECKING:
from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer # tests_ignore
class ParameterFormat(Enum):
@@ -66,7 +68,7 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
return num_parameters * dtype.size
def get_preprocessor(model_name: str) -> Optional[Union[AutoTokenizer, AutoFeatureExtractor, AutoProcessor]]:
def get_preprocessor(model_name: str) -> Optional[Union["AutoTokenizer", "AutoFeatureExtractor", "AutoProcessor"]]:
"""
Gets a preprocessor (tokenizer, feature extractor or processor) that is available for `model_name`.
@@ -79,6 +81,9 @@ def get_preprocessor(model_name: str) -> Optional[Union[AutoTokenizer, AutoFeatu
returned. If both a tokenizer and a feature extractor exist, an error is raised. The function returns
`None` if no preprocessor is found.
"""
# Avoid circular imports by only importing this here.
from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer # tests_ignore
try:
return AutoProcessor.from_pretrained(model_name)
except (ValueError, OSError, KeyError):