From e86faecfd43b8a8e296f14b4dbe21a68f3e2a4ed Mon Sep 17 00:00:00 2001 From: cloudhan Date: Mon, 16 May 2022 19:08:04 +0800 Subject: [PATCH] Fix obvious typos in flax decoder impl (#17279) Change config.encoder_ffn_dim -> config.decoder_ffn_dim for decoder. --- src/transformers/models/bart/modeling_flax_bart.py | 2 +- .../models/blenderbot/modeling_flax_blenderbot.py | 2 +- .../blenderbot_small/modeling_flax_blenderbot_small.py | 2 +- src/transformers/models/marian/modeling_flax_marian.py | 2 +- src/transformers/models/mbart/modeling_flax_mbart.py | 2 +- src/transformers/models/pegasus/modeling_flax_pegasus.py | 2 +- .../modeling_flax_{{cookiecutter.lowercase_modelname}}.py | 6 +++--- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index 55d32a3f06..5704147872 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -537,7 +537,7 @@ class FlaxBartDecoderLayer(nn.Module): ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, + self.config.decoder_ffn_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py index 7f30878772..a75fe4d5b7 100644 --- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py @@ -528,7 +528,7 @@ class FlaxBlenderbotDecoderLayer(nn.Module): ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, + self.config.decoder_ffn_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py index c08e277282..ddace51e7e 100644 --- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py @@ -541,7 +541,7 @@ class FlaxBlenderbotSmallDecoderLayer(nn.Module): ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, + self.config.decoder_ffn_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index 8fea39e19a..da2e4a1fe5 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -551,7 +551,7 @@ class FlaxMarianDecoderLayer(nn.Module): ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, + self.config.decoder_ffn_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py index 141d2b1041..7cb52033b7 100644 --- a/src/transformers/models/mbart/modeling_flax_mbart.py +++ b/src/transformers/models/mbart/modeling_flax_mbart.py @@ -550,7 +550,7 @@ class FlaxMBartDecoderLayer(nn.Module): ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, + self.config.decoder_ffn_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py index 81276dcd2a..303d005571 100644 --- a/src/transformers/models/pegasus/modeling_flax_pegasus.py +++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py @@ -544,7 +544,7 @@ class FlaxPegasusDecoderLayer(nn.Module): ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, + self.config.decoder_ffn_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py index 43fbad2495..451dc03f62 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py @@ -1996,7 +1996,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module): ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.fc1 = nn.Dense( - self.config.encoder_ffn_dim, + self.config.decoder_ffn_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) @@ -2997,10 +2997,10 @@ FLAX_{{cookiecutter.uppercase_modelname}}_CONDITIONAL_GENERATION_DOCSTRING = """ ```python >>> import jax >>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration - + >>> model = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}') >>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}') - + >>> TXT = "My friends are but they eat too many carbs." >>> input_ids = tokenizer([TXT], return_tensors='np')['input_ids']