TF ALBERT + TF Utilities + Fix warnings

This commit is contained in:
Lysandre
2020-01-15 15:50:30 -05:00
committed by Lysandre Debut
parent 00df3d4de0
commit 3922a2497e
6 changed files with 148 additions and 126 deletions

View File

@@ -91,7 +91,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
self.config = config
def get_input_embeddings(self):
""" Get model's input embeddings
"""
Returns the model's input embeddings.
Returns:
:obj:`tf.keras.layers.Layer`:
A torch module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
@@ -100,8 +105,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
raise NotImplementedError
def get_output_embeddings(self):
""" Get model's output embeddings
Return None if the model doesn't have output embeddings
"""
Returns the model's output embeddings.
Returns:
:obj:`tf.keras.layers.Layer`:
A torch module mapping hidden states to vocabulary.
"""
return None # Overwrite for models with output embeddings