Inheritance-based framework detection (#21784)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
import collections
|
||||
import csv
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
@@ -269,7 +270,7 @@ def infer_framework_load_model(
|
||||
if isinstance(model, str):
|
||||
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
|
||||
|
||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
||||
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
|
||||
return framework, model
|
||||
|
||||
|
||||
@@ -342,7 +343,7 @@ def get_framework(model, revision: Optional[str] = None):
|
||||
except OSError:
|
||||
model = TFAutoModel.from_pretrained(model, revision=revision)
|
||||
|
||||
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
|
||||
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
|
||||
return framework
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user