Fix template (#9512)

This commit is contained in:
Julien Plu
2021-01-11 14:03:28 +01:00
committed by GitHub
parent d415882b41
commit 1e3c362235

View File

@@ -462,7 +462,7 @@ class TF{{cookiecutter.camelcase_modelname}}LMPredictionHead(tf.keras.layers.Lay
super().build(input_shape) super().build(input_shape)
def get_output_embeddings(self): def get_output_embeddings(self):
return self.input_embeddings.word_embeddings return self.input_embeddings
def set_output_embeddings(self, value): def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value self.input_embeddings.word_embeddings = value