adding more tests on TF and pytorch serialization - updating configuration for better serialization
This commit is contained in:
@@ -153,7 +153,7 @@ class PretrainedConfig(object):
|
||||
config = cls.from_json_file(resolved_config_file)
|
||||
|
||||
if hasattr(config, 'pruned_heads'):
|
||||
config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
|
||||
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
|
||||
|
||||
# Update config with kwargs if needed
|
||||
to_remove = []
|
||||
@@ -164,7 +164,7 @@ class PretrainedConfig(object):
|
||||
for key in to_remove:
|
||||
kwargs.pop(key, None)
|
||||
|
||||
logger.info("Model config %s", config)
|
||||
logger.info("Model config %s", str(config))
|
||||
if return_unused_kwargs:
|
||||
return config, kwargs
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user