TF ALBERT + TF Utilities + Fix warnings
This commit is contained in:
@@ -91,7 +91,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
self.config = config
|
||||
|
||||
def get_input_embeddings(self):
|
||||
""" Get model's input embeddings
|
||||
"""
|
||||
Returns the model's input embeddings.
|
||||
|
||||
Returns:
|
||||
:obj:`tf.keras.layers.Layer`:
|
||||
A torch module mapping vocabulary to hidden states.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
if base_model is not self:
|
||||
@@ -100,8 +105,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_output_embeddings(self):
|
||||
""" Get model's output embeddings
|
||||
Return None if the model doesn't have output embeddings
|
||||
"""
|
||||
Returns the model's output embeddings.
|
||||
|
||||
Returns:
|
||||
:obj:`tf.keras.layers.Layer`:
|
||||
A torch module mapping hidden states to vocabulary.
|
||||
"""
|
||||
return None # Overwrite for models with output embeddings
|
||||
|
||||
|
||||
Reference in New Issue
Block a user