From 744295636116eac1c0b84e23e9b3cab90886a45d Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 12 Jul 2019 11:26:16 +0200 Subject: [PATCH] save config file --- pytorch_transformers/modeling_utils.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 9ca3a3d090..bb2b82b41c 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -48,6 +48,17 @@ class PretrainedConfig(object): self.output_hidden_states = kwargs.pop('output_hidden_states', False) self.torchscript = kwargs.pop('torchscript', False) + def save_pretrained(self, save_directory): + """ Save a configuration file to a directory, so that it + can be re-loaded using the `from_pretrained(save_directory)` class method. + """ + assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" + + # If we save using the predefined names, we can load using `from_pretrained` + output_config_file = os.path.join(save_directory, CONFIG_NAME) + + self.to_json_file(output_config_file) + @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs): """ @@ -248,12 +259,13 @@ class PreTrainedModel(nn.Module): # Only save the model it-self if we are using distributed training model_to_save = self.module if hasattr(self, 'module') else self + # Save configuration file + model_to_save.config.save_pretrained(save_directory) + # If we save using the predefined names, we can load using `from_pretrained` output_model_file = os.path.join(save_directory, WEIGHTS_NAME) - output_config_file = os.path.join(save_directory, CONFIG_NAME) torch.save(model_to_save.state_dict(), output_model_file) - model_to_save.config.to_json_file(output_config_file) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):