Fix template (#9512)
This commit is contained in:
@@ -462,7 +462,7 @@ class TF{{cookiecutter.camelcase_modelname}}LMPredictionHead(tf.keras.layers.Lay
|
||||
super().build(input_shape)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.input_embeddings.word_embeddings
|
||||
return self.input_embeddings
|
||||
|
||||
def set_output_embeddings(self, value):
|
||||
self.input_embeddings.word_embeddings = value
|
||||
|
||||
Reference in New Issue
Block a user