switch from properties to methods

This commit is contained in:
thomwolf
2019-11-04 15:34:10 +01:00
parent 9b45d0f878
commit 1724cee8c4
12 changed files with 70 additions and 75 deletions

View File

@@ -87,21 +87,37 @@ class PreTrainedModel(nn.Module):
def base_model(self):
return getattr(self, self.base_model_prefix, self)
@property
def input_embeddings(self):
def get_input_embeddings(self):
""" Get model's input embeddings
"""
base_model = getattr(self, self.base_model_prefix, self)
return base_model.input_embeddings
if base_model is not self:
return base_model.get_input_embeddings()
else:
raise NotImplementedError
@property
def output_embeddings(self):
def set_input_embeddings(self, value):
""" Set model's input embeddings
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
base_model.set_input_embeddings(value)
else:
raise NotImplementedError
def get_output_embeddings(self):
""" Get model's output embeddings
Return None if the model doesn't have output embeddings
"""
return None # Overwrite for models with output embeddings
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
if self.output_embeddings is not None:
self._tie_or_clone_weights(self.output_embeddings, self.input_embeddings)
output_embeddings = self.get_output_embeddings()
if output_embeddings is not None:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
""" Tie or clone module weights depending of weither we are using TorchScript or not
@@ -150,9 +166,10 @@ class PreTrainedModel(nn.Module):
return model_embeds
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.input_embeddings
self.input_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
return self.input_embeddings
old_embeddings = self.get_input_embeddings()
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.set_input_embeddings(new_embeddings)
return self.get_input_embeddings()
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
""" Build a resized Embedding Module from a provided token Embedding Module.