[Flax] Fixes typo in Bart based Flax Models (#13565)

This commit is contained in:
Bhadresh Savani
2021-09-15 11:03:52 +05:30
committed by GitHub
parent 7bd16b8776
commit 3fbb55c757
5 changed files with 20 additions and 20 deletions

View File

@@ -1432,7 +1432,7 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
@@ -1459,7 +1459,7 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
@@ -1541,7 +1541,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.encoder_attn = Flax{{cookiecutter.camelcase_modelname}}Attention(
@@ -1598,7 +1598,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states