unified tokenizer api and serialization + tests
This commit is contained in:
@@ -169,6 +169,22 @@ class PreTrainedModel(nn.Module):
|
||||
model_to_prune = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
model_to_prune._prune_heads(heads_to_prune)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a model with its configuration file to a directory, so that it
|
||||
can be re-loaded using the `from_pretrained(save_directory)` class method.
|
||||
"""
|
||||
assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
|
||||
|
||||
# Only save the model it-self if we are using distributed training
|
||||
model_to_save = self.module if hasattr(self, 'module') else self
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
||||
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
model_to_save.config.to_json_file(output_config_file)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user