save config file
This commit is contained in:
@@ -48,6 +48,17 @@ class PretrainedConfig(object):
|
|||||||
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
|
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
|
||||||
self.torchscript = kwargs.pop('torchscript', False)
|
self.torchscript = kwargs.pop('torchscript', False)
|
||||||
|
|
||||||
|
def save_pretrained(self, save_directory):
|
||||||
|
""" Save a configuration file to a directory, so that it
|
||||||
|
can be re-loaded using the `from_pretrained(save_directory)` class method.
|
||||||
|
"""
|
||||||
|
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
|
||||||
|
|
||||||
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
|
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
||||||
|
|
||||||
|
self.to_json_file(output_config_file)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -248,12 +259,13 @@ class PreTrainedModel(nn.Module):
|
|||||||
# Only save the model it-self if we are using distributed training
|
# Only save the model it-self if we are using distributed training
|
||||||
model_to_save = self.module if hasattr(self, 'module') else self
|
model_to_save = self.module if hasattr(self, 'module') else self
|
||||||
|
|
||||||
|
# Save configuration file
|
||||||
|
model_to_save.config.save_pretrained(save_directory)
|
||||||
|
|
||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||||
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
|
||||||
|
|
||||||
torch.save(model_to_save.state_dict(), output_model_file)
|
torch.save(model_to_save.state_dict(), output_model_file)
|
||||||
model_to_save.config.to_json_file(output_config_file)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user