diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index feffe7a585..7ff9a10eb1 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -305,7 +305,7 @@ class TF{{cookiecutter.camelcase_modelname}}Intermediate(tf.keras.layers.Layer): # Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->{{cookiecutter.camelcase_modelname}} class TF{{cookiecutter.camelcase_modelname}}Output(tf.keras.layers.Layer): - def __init__(self, config: BertConfig, **kwargs): + def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, **kwargs): super().__init__(**kwargs) self.dense = tf.keras.layers.Dense(