TF version compatibility fixes (#23663)
* New TF version compatibility fixes * Remove dummy print statement, move expand_1d * Make a proper framework inference function * Make a proper framework inference function * ValueError -> TypeError
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
import collections
|
||||
import csv
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
@@ -36,7 +35,7 @@ from ..image_processing_utils import BaseImageProcessor
|
||||
from ..modelcard import ModelCard
|
||||
from ..models.auto.configuration_auto import AutoConfig
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available, logging
|
||||
from ..utils import ModelOutput, add_end_docstrings, infer_framework, is_tf_available, is_torch_available, logging
|
||||
|
||||
|
||||
GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]
|
||||
@@ -278,7 +277,7 @@ def infer_framework_load_model(
|
||||
if isinstance(model, str):
|
||||
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
|
||||
|
||||
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
|
||||
framework = infer_framework(model.__class__)
|
||||
return framework, model
|
||||
|
||||
|
||||
@@ -351,7 +350,7 @@ def get_framework(model, revision: Optional[str] = None):
|
||||
except OSError:
|
||||
model = TFAutoModel.from_pretrained(model, revision=revision)
|
||||
|
||||
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
|
||||
framework = infer_framework(model.__class__)
|
||||
return framework
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user