TF version compatibility fixes (#23663)

* New TF version compatibility fixes

* Remove dummy print statement, move expand_1d

* Make a proper framework inference function

* Make a proper framework inference function

* ValueError -> TypeError
This commit is contained in:
Matt
2023-05-23 16:42:11 +01:00
committed by GitHub
parent 42baa58f90
commit 876d9a32c6
5 changed files with 128 additions and 30 deletions

View File

@@ -15,7 +15,6 @@
import collections
import csv
import importlib
import inspect
import json
import os
import pickle
@@ -36,7 +35,7 @@ from ..image_processing_utils import BaseImageProcessor
from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available, logging
from ..utils import ModelOutput, add_end_docstrings, infer_framework, is_tf_available, is_torch_available, logging
GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]
@@ -278,7 +277,7 @@ def infer_framework_load_model(
if isinstance(model, str):
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
framework = infer_framework(model.__class__)
return framework, model
@@ -351,7 +350,7 @@ def get_framework(model, revision: Optional[str] = None):
except OSError:
model = TFAutoModel.from_pretrained(model, revision=revision)
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
framework = infer_framework(model.__class__)
return framework