From e9d0bc027a911059e6e01f78f8005036d2880b06 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 18 Apr 2020 02:07:18 +0200 Subject: [PATCH] [Config, Serialization] more readable config serialization (#3797) * better config serialization * finish configuration utils --- src/transformers/configuration_utils.py | 43 ++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 0e4a58e7ca..2066d83d65 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -141,7 +141,7 @@ class PretrainedConfig(object): # If we save using the predefined names, we can load using `from_pretrained` output_config_file = os.path.join(save_directory, CONFIG_NAME) - self.to_json_file(output_config_file) + self.to_json_file(output_config_file, use_diff=True) logger.info("Configuration saved in {}".format(output_config_file)) @classmethod @@ -353,6 +353,29 @@ class PretrainedConfig(object): def __repr__(self): return "{} {}".format(self.__class__.__name__, self.to_json_string()) + def to_diff_dict(self): + """ + Removes all attributes from config which correspond to the default + config attributes for better readability and serializes to a Python + dictionary. + + Returns: + :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = PretrainedConfig().to_dict() + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if key not in default_config_dict or value != default_config_dict[key]: + serializable_config_dict[key] = value + + return serializable_config_dict + def to_dict(self): """ Serializes this instance to a Python dictionary. @@ -365,25 +388,35 @@ class PretrainedConfig(object): output["model_type"] = self.__class__.model_type return output - def to_json_string(self): + def to_json_string(self, use_diff=True): """ Serializes this instance to a JSON string. + Args: + use_diff (:obj:`bool`): + If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string. + Returns: :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format. """ - return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - def to_json_file(self, json_file_path): + def to_json_file(self, json_file_path, use_diff=True): """ Save this instance to a json file. Args: json_file_path (:obj:`string`): Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (:obj:`bool`): + If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file. """ with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string()) + writer.write(self.to_json_string(use_diff=use_diff)) def update(self, config_dict: Dict): """