Fix obvious typos in flax decoder impl (#17279)

Change config.encoder_ffn_dim -> config.decoder_ffn_dim for decoder.
This commit is contained in:
cloudhan
2022-05-16 19:08:04 +08:00
committed by GitHub
parent ee393c009a
commit e86faecfd4
7 changed files with 9 additions and 9 deletions

View File

@@ -1996,7 +1996,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
@@ -2997,10 +2997,10 @@ FLAX_{{cookiecutter.uppercase_modelname}}_CONDITIONAL_GENERATION_DOCSTRING = """
```python
>>> import jax
>>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration
>>> model = Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
>>> TXT = "My friends are <mask> but they eat too many carbs."
>>> input_ids = tokenizer([TXT], return_tensors='np')['input_ids']