ALBERT Modeling + required changes to utilities
This commit is contained in:
@@ -114,7 +114,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
return getattr(self, self.base_model_prefix, self)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
""" Get model's input embeddings
|
||||
"""
|
||||
Returns the model's input embeddings.
|
||||
|
||||
Returns:
|
||||
:obj:`nn.Module`:
|
||||
A torch module mapping vocabulary to hidden states.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
if base_model is not self:
|
||||
@@ -123,7 +128,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
""" Set model's input embeddings
|
||||
"""
|
||||
Set model's input embeddings
|
||||
|
||||
Args:
|
||||
value (:obj:`nn.Module`):
|
||||
A module mapping vocabulary to hidden states.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
if base_model is not self:
|
||||
@@ -132,14 +142,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_output_embeddings(self):
|
||||
""" Get model's output embeddings
|
||||
Return None if the model doesn't have output embeddings
|
||||
"""
|
||||
Returns the model's output embeddings.
|
||||
|
||||
Returns:
|
||||
:obj:`nn.Module`:
|
||||
A torch module mapping hidden states to vocabulary.
|
||||
"""
|
||||
return None # Overwrite for models with output embeddings
|
||||
|
||||
def tie_weights(self):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
Tie the weights between the input embeddings and the output embeddings.
|
||||
If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
|
||||
the weights instead.
|
||||
"""
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
if output_embeddings is not None:
|
||||
|
||||
Reference in New Issue
Block a user