[TF models] Common attributes as per #1721
This commit is contained in:
@@ -65,6 +65,21 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
# Save config in model
|
||||
self.config = config
|
||||
|
||||
def get_input_embeddings(self):
|
||||
""" Get model's input embeddings
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self)
|
||||
if base_model is not self:
|
||||
return base_model.get_input_embeddings()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_output_embeddings(self):
|
||||
""" Get model's output embeddings
|
||||
Return None if the model doesn't have output embeddings
|
||||
"""
|
||||
return None # Overwrite for models with output embeddings
|
||||
|
||||
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
|
||||
""" Build a resized Embedding Variable from a provided token Embedding Module.
|
||||
Increasing the size will add newly initialized vectors at the end
|
||||
|
||||
Reference in New Issue
Block a user