Explicit config_class instead of module inspection

This commit is contained in:
Gunnlaugur Thor Briem
2020-03-04 23:45:29 +00:00
parent 6fe1cc0874
commit 4f338ed407
8 changed files with 17 additions and 17 deletions

View File

@@ -50,11 +50,13 @@ class TFModelUtilsMixin:
def keras_serializable(cls):
initializer = cls.__init__
config_class = getattr(cls, "config_class", None)
if config_class is None:
raise AttributeError("Must set `config_class` to use @keras_serializable")
def wrapped_init(self, config, *args, **kwargs):
if isinstance(config, dict):
from transformers import AutoConfig
config = AutoConfig.config_class_for_model_class(cls).from_dict(config)
config = config_class.from_dict(config)
initializer(self, config, *args, **kwargs)
self._transformers_config = config