ALBERT Modeling + required changes to utilities

This commit is contained in:
Lysandre
2020-01-15 14:20:17 -05:00
committed by Lysandre Debut
parent f81b6c95f2
commit 00df3d4de0
4 changed files with 259 additions and 166 deletions

View File

@@ -114,7 +114,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return getattr(self, self.base_model_prefix, self)
def get_input_embeddings(self):
""" Get model's input embeddings
"""
Returns the model's input embeddings.
Returns:
:obj:`nn.Module`:
A torch module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
@@ -123,7 +128,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
raise NotImplementedError
def set_input_embeddings(self, value):
""" Set model's input embeddings
"""
Sets the model's input embeddings.
Args:
value (:obj:`nn.Module`):
A module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
@@ -132,14 +142,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
raise NotImplementedError
def get_output_embeddings(self):
""" Get model's output embeddings
Return None if the model doesn't have output embeddings
"""
Returns the model's output embeddings.
Returns:
:obj:`nn.Module`:
A torch module mapping hidden states to vocabulary, or ``None`` if the model has no output embeddings.
"""
return None # Overwrite for models with output embeddings
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
Tie the weights between the input embeddings and the output embeddings.
If the `torchscript` flag is set in the configuration, the weights are cloned instead of shared, because
export to TorchScript can't handle parameter sharing.
"""
output_embeddings = self.get_output_embeddings()
if output_embeddings is not None: