Explicit config_class instead of module inspection
This commit is contained in:
@@ -50,11 +50,13 @@ class TFModelUtilsMixin:
|
||||
def keras_serializable(cls):
|
||||
initializer = cls.__init__
|
||||
|
||||
config_class = getattr(cls, "config_class", None)
|
||||
if config_class is None:
|
||||
raise AttributeError("Must set `config_class` to use @keras_serializable")
|
||||
|
||||
def wrapped_init(self, config, *args, **kwargs):
|
||||
if isinstance(config, dict):
|
||||
from transformers import AutoConfig
|
||||
|
||||
config = AutoConfig.config_class_for_model_class(cls).from_dict(config)
|
||||
config = config_class.from_dict(config)
|
||||
initializer(self, config, *args, **kwargs)
|
||||
self._transformers_config = config
|
||||
|
||||
|
||||
Reference in New Issue
Block a user