From 6587cf9f8448b5573cf4a1c639ef4857472d1da0 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 14 Dec 2020 00:39:55 -0500 Subject: [PATCH] Patch *ForCausalLM model (#9092) --- .../modeling_tf_{{cookiecutter.lowercase_modelname}}.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 15ac9571ba..109b9f310b 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -872,6 +872,12 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca def get_output_embeddings(self): return self.{{cookiecutter.lowercase_modelname}}.embeddings + def get_output_layer_with_bias(self): + return self.mlm.predictions + + def get_prefix_bias_name(self): + return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name + @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="{{cookiecutter.checkpoint_identifier}}",