[PretrainedConfig] Fix save pretrained config for edge case (#7943)

* fix config save

* add test

* add config class variable and another test

* line break

* fix fsmt and typo

* god am I making many errors today :-/

* Update src/transformers/configuration_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-10-22 15:39:01 +02:00
committed by GitHub
parent cc2e312ca3
commit f34372a9ff
6 changed files with 34 additions and 4 deletions

View File

@@ -41,6 +41,10 @@ class PretrainedConfig(object):
Class attributes (overridden by derived classes)
- **model_type** (:obj:`str`): An identifier for the model type, serialized into the JSON file, and used to
recreate the correct object in :class:`~transformers.AutoConfig`.
- **is_composition** (:obj:`bool`): Whether the config class is composed of multiple
sub-configs. In this case the config has to be initialized from two or more configs of
type :class:`~transformers.PretrainedConfig` like: :class:`~transformers.EncoderDecoderConfig` or
:class:`~RagConfig`.
Args:
name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`):
@@ -145,6 +149,7 @@ class PretrainedConfig(object):
use BFloat16 scalars (only used by some TensorFlow models).
"""
model_type: str = ""
is_composition: bool = False
def __init__(self, **kwargs):
# Attributes with defaults
@@ -476,11 +481,18 @@ class PretrainedConfig(object):
# get the default config dict
default_config_dict = PretrainedConfig().to_dict()
# get class specific config dict
class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
serializable_config_dict = {}
# only serialize values that differ from the default config
for key, value in config_dict.items():
if key not in default_config_dict or value != default_config_dict[key]:
if (
key not in default_config_dict
or value != default_config_dict[key]
or (key in class_config_dict and value != class_config_dict[key])
):
serializable_config_dict[key] = value
return serializable_config_dict