Fix circular import in onnx.utils (#17577)
* Fix circular import in onnx.utils * Add comment for test fetcher * Here too * Style
This commit is contained in:
@@ -14,9 +14,11 @@
|
|||||||
|
|
||||||
from ctypes import c_float, sizeof
|
from ctypes import c_float, sizeof
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer # tests_ignore
|
||||||
|
|
||||||
|
|
||||||
class ParameterFormat(Enum):
|
class ParameterFormat(Enum):
|
||||||
@@ -66,7 +68,7 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
|
|||||||
return num_parameters * dtype.size
|
return num_parameters * dtype.size
|
||||||
|
|
||||||
|
|
||||||
def get_preprocessor(model_name: str) -> Optional[Union[AutoTokenizer, AutoFeatureExtractor, AutoProcessor]]:
|
def get_preprocessor(model_name: str) -> Optional[Union["AutoTokenizer", "AutoFeatureExtractor", "AutoProcessor"]]:
|
||||||
"""
|
"""
|
||||||
Gets a preprocessor (tokenizer, feature extractor or processor) that is available for `model_name`.
|
Gets a preprocessor (tokenizer, feature extractor or processor) that is available for `model_name`.
|
||||||
|
|
||||||
@@ -79,6 +81,9 @@ def get_preprocessor(model_name: str) -> Optional[Union[AutoTokenizer, AutoFeatu
|
|||||||
returned. If both a tokenizer and a feature extractor exist, an error is raised. The function returns
|
returned. If both a tokenizer and a feature extractor exist, an error is raised. The function returns
|
||||||
`None` if no preprocessor is found.
|
`None` if no preprocessor is found.
|
||||||
"""
|
"""
|
||||||
|
# Avoid circular imports by only importing this here.
|
||||||
|
from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer # tests_ignore
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return AutoProcessor.from_pretrained(model_name)
|
return AutoProcessor.from_pretrained(model_name)
|
||||||
except (ValueError, OSError, KeyError):
|
except (ValueError, OSError, KeyError):
|
||||||
|
|||||||
Reference in New Issue
Block a user