Fix template (#9512)
This commit is contained in:
@@ -462,7 +462,7 @@ class TF{{cookiecutter.camelcase_modelname}}LMPredictionHead(tf.keras.layers.Lay
|
|||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.input_embeddings.word_embeddings
|
return self.input_embeddings
|
||||||
|
|
||||||
def set_output_embeddings(self, value):
|
def set_output_embeddings(self, value):
|
||||||
self.input_embeddings.word_embeddings = value
|
self.input_embeddings.word_embeddings = value
|
||||||
|
|||||||
Reference in New Issue
Block a user