[Config, Serialization] more readable config serialization (#3797)
* better config serialization * finish configuration utils
This commit is contained in:
committed by
GitHub
parent
8b63a01d95
commit
e9d0bc027a
@@ -141,7 +141,7 @@ class PretrainedConfig(object):
|
|||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
||||||
|
|
||||||
self.to_json_file(output_config_file)
|
self.to_json_file(output_config_file, use_diff=True)
|
||||||
logger.info("Configuration saved in {}".format(output_config_file))
|
logger.info("Configuration saved in {}".format(output_config_file))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -353,6 +353,29 @@ class PretrainedConfig(object):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{} {}".format(self.__class__.__name__, self.to_json_string())
|
return "{} {}".format(self.__class__.__name__, self.to_json_string())
|
||||||
|
|
||||||
|
def to_diff_dict(self):
|
||||||
|
"""
|
||||||
|
Removes all attributes from config which correspond to the default
|
||||||
|
config attributes for better readability and serializes to a Python
|
||||||
|
dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||||
|
"""
|
||||||
|
config_dict = self.to_dict()
|
||||||
|
|
||||||
|
# get the default config dict
|
||||||
|
default_config_dict = PretrainedConfig().to_dict()
|
||||||
|
|
||||||
|
serializable_config_dict = {}
|
||||||
|
|
||||||
|
# only serialize values that differ from the default config
|
||||||
|
for key, value in config_dict.items():
|
||||||
|
if key not in default_config_dict or value != default_config_dict[key]:
|
||||||
|
serializable_config_dict[key] = value
|
||||||
|
|
||||||
|
return serializable_config_dict
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
"""
|
"""
|
||||||
Serializes this instance to a Python dictionary.
|
Serializes this instance to a Python dictionary.
|
||||||
@@ -365,25 +388,35 @@ class PretrainedConfig(object):
|
|||||||
output["model_type"] = self.__class__.model_type
|
output["model_type"] = self.__class__.model_type
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def to_json_string(self):
|
def to_json_string(self, use_diff=True):
|
||||||
"""
|
"""
|
||||||
Serializes this instance to a JSON string.
|
Serializes this instance to a JSON string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_diff (:obj:`bool`):
|
||||||
|
If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
|
:obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
|
||||||
"""
|
"""
|
||||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
if use_diff is True:
|
||||||
|
config_dict = self.to_diff_dict()
|
||||||
|
else:
|
||||||
|
config_dict = self.to_dict()
|
||||||
|
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
def to_json_file(self, json_file_path):
|
def to_json_file(self, json_file_path, use_diff=True):
|
||||||
"""
|
"""
|
||||||
Save this instance to a json file.
|
Save this instance to a json file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
json_file_path (:obj:`string`):
|
json_file_path (:obj:`string`):
|
||||||
Path to the JSON file in which this configuration instance's parameters will be saved.
|
Path to the JSON file in which this configuration instance's parameters will be saved.
|
||||||
|
use_diff (:obj:`bool`):
|
||||||
|
If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
|
||||||
"""
|
"""
|
||||||
with open(json_file_path, "w", encoding="utf-8") as writer:
|
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||||
writer.write(self.to_json_string())
|
writer.write(self.to_json_string(use_diff=use_diff))
|
||||||
|
|
||||||
def update(self, config_dict: Dict):
|
def update(self, config_dict: Dict):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user