Fix obvious typos in flax decoder impl (#17279)
Change config.encoder_ffn_dim -> config.decoder_ffn_dim for decoder.
This commit is contained in:
@@ -537,7 +537,7 @@ class FlaxBartDecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||||
self.fc1 = nn.Dense(
|
self.fc1 = nn.Dense(
|
||||||
self.config.encoder_ffn_dim,
|
self.config.decoder_ffn_dim,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -528,7 +528,7 @@ class FlaxBlenderbotDecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||||
self.fc1 = nn.Dense(
|
self.fc1 = nn.Dense(
|
||||||
self.config.encoder_ffn_dim,
|
self.config.decoder_ffn_dim,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -541,7 +541,7 @@ class FlaxBlenderbotSmallDecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||||
self.fc1 = nn.Dense(
|
self.fc1 = nn.Dense(
|
||||||
self.config.encoder_ffn_dim,
|
self.config.decoder_ffn_dim,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -551,7 +551,7 @@ class FlaxMarianDecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||||
self.fc1 = nn.Dense(
|
self.fc1 = nn.Dense(
|
||||||
self.config.encoder_ffn_dim,
|
self.config.decoder_ffn_dim,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -550,7 +550,7 @@ class FlaxMBartDecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||||
self.fc1 = nn.Dense(
|
self.fc1 = nn.Dense(
|
||||||
self.config.encoder_ffn_dim,
|
self.config.decoder_ffn_dim,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -544,7 +544,7 @@ class FlaxPegasusDecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||||
self.fc1 = nn.Dense(
|
self.fc1 = nn.Dense(
|
||||||
self.config.encoder_ffn_dim,
|
self.config.decoder_ffn_dim,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1996,7 +1996,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
||||||
self.fc1 = nn.Dense(
|
self.fc1 = nn.Dense(
|
||||||
self.config.encoder_ffn_dim,
|
self.config.decoder_ffn_dim,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
||||||
)
|
)
|
||||||
@@ -2997,10 +2997,10 @@ FLAX_{{cookiecutter.uppercase_modelname}}_CONDITIONAL_GENERATION_DOCSTRING = """
|
|||||||
```python
|
```python
|
||||||
>>> import jax
|
>>> import jax
|
||||||
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
|
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
|
||||||
|
|
||||||
>>> model = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
|
>>> model = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
|
||||||
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
|
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
|
||||||
|
|
||||||
>>> TXT = "My friends are <mask> but they eat too many carbs."
|
>>> TXT = "My friends are <mask> but they eat too many carbs."
|
||||||
>>> input_ids = tokenizer([TXT], return_tensors='np')['input_ids']
|
>>> input_ids = tokenizer([TXT], return_tensors='np')['input_ids']
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user