switch from properties to methods
This commit is contained in:
@@ -87,21 +87,37 @@ class PreTrainedModel(nn.Module):
|
||||
def base_model(self):
|
||||
return getattr(self, self.base_model_prefix, self)
|
||||
|
||||
@property
|
||||
def input_embeddings(self):
|
||||
def get_input_embeddings(self):
|
||||
""" Get model's input embeddings
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
return base_model.input_embeddings
|
||||
if base_model is not self:
|
||||
return base_model.get_input_embeddings()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def output_embeddings(self):
|
||||
def set_input_embeddings(self, value):
|
||||
""" Set model's input embeddings
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
if base_model is not self:
|
||||
base_model.set_input_embeddings(value)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_output_embeddings(self):
|
||||
""" Get model's output embeddings
|
||||
Return None if the model doesn't have output embeddings
|
||||
"""
|
||||
return None # Overwrite for models with output embeddings
|
||||
|
||||
def tie_weights(self):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
if self.output_embeddings is not None:
|
||||
self._tie_or_clone_weights(self.output_embeddings, self.input_embeddings)
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
if output_embeddings is not None:
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
|
||||
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
|
||||
""" Tie or clone module weights depending of weither we are using TorchScript or not
|
||||
@@ -150,9 +166,10 @@ class PreTrainedModel(nn.Module):
|
||||
return model_embeds
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
old_embeddings = self.input_embeddings
|
||||
self.input_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
return self.input_embeddings
|
||||
old_embeddings = self.get_input_embeddings()
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
self.set_input_embeddings(new_embeddings)
|
||||
return self.get_input_embeddings()
|
||||
|
||||
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
|
||||
""" Build a resized Embedding Module from a provided token Embedding Module.
|
||||
|
||||
Reference in New Issue
Block a user