Inheritance-based framework detection (#21784)

This commit is contained in:
Joao Gante
2023-02-27 15:31:55 +00:00
committed by GitHub
parent 7811bf7e73
commit 92dfceb124
3 changed files with 60 additions and 35 deletions

View File

@@ -15,6 +15,7 @@
import collections
import csv
import importlib
import inspect
import json
import os
import pickle
@@ -269,7 +270,7 @@ def infer_framework_load_model(
if isinstance(model, str):
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
return framework, model
@@ -342,7 +343,7 @@ def get_framework(model, revision: Optional[str] = None):
except OSError:
model = TFAutoModel.from_pretrained(model, revision=revision)
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
return framework