From 607acd4fbd175feb458eda7317419cdee5d443fe Mon Sep 17 00:00:00 2001 From: DanielHesslow Date: Fri, 3 Jun 2022 10:56:37 +0200 Subject: [PATCH] Add Gated-SiLU to T5 (#17420) * Add gated-silu to t5 architecture to support UL2 * Fix error message * formatting * formatting again * refactor * fix classnames in _init_weights * remove is_gated * add test * fix test * Try without the test? * Add back the test. * Improve error message. Co-authored-by: Daniel Hesslow --- .../models/t5/configuration_t5.py | 16 ++++++++++++ .../models/t5/modeling_flax_t5.py | 21 +++++++-------- src/transformers/models/t5/modeling_t5.py | 26 ++++++++----------- src/transformers/models/t5/modeling_tf_t5.py | 19 ++++++-------- tests/models/t5/test_modeling_t5.py | 6 +++++ 5 files changed, 50 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index b09539c86d..a2bd03dfd7 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -116,6 +116,22 @@ class T5Config(PretrainedConfig): self.initializer_factor = initializer_factor self.feed_forward_proj = feed_forward_proj self.use_cache = use_cache + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer." + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. " + "'gated-gelu' or 'relu'" + ) + + # for backwards compatibility + if feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + super().__init__( pad_token_id=pad_token_id, eos_token_id=eos_token_id, diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index a6e1da70bb..23f7436eab 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -87,7 +87,7 @@ class FlaxT5LayerNorm(nn.Module): return self.weight * hidden_states -class FlaxT5DenseReluDense(nn.Module): +class FlaxT5DenseActDense(nn.Module): config: T5Config dtype: jnp.dtype = jnp.float32 @@ -108,16 +108,17 @@ class FlaxT5DenseReluDense(nn.Module): dtype=self.dtype, ) self.dropout = nn.Dropout(self.config.dropout_rate) + self.act = ACT2FN[self.config.dense_act_fn] def __call__(self, hidden_states, deterministic=True): hidden_states = self.wi(hidden_states) - hidden_states = jax.nn.relu(hidden_states) + hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states, deterministic=deterministic) hidden_states = self.wo(hidden_states) return hidden_states -class FlaxT5DenseGatedGeluDense(nn.Module): +class FlaxT5DenseGatedActDense(nn.Module): config: T5Config dtype: jnp.dtype = jnp.float32 # the dtype of the computation @@ -144,10 +145,10 @@ class FlaxT5DenseGatedGeluDense(nn.Module): dtype=self.dtype, ) self.dropout = nn.Dropout(self.config.dropout_rate) - self.gelu_act = ACT2FN["gelu_new"] + self.act = ACT2FN[self.config.dense_act_fn] def __call__(self, hidden_states, deterministic): - hidden_gelu = self.gelu_act(self.wi_0(hidden_states)) + hidden_gelu = self.act(self.wi_0(hidden_states)) hidden_linear = self.wi_1(hidden_states) hidden_states = hidden_gelu * hidden_linear hidden_states = self.dropout(hidden_states, deterministic=deterministic) @@ -160,14 +161,10 @@ class FlaxT5LayerFF(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - if self.config.feed_forward_proj == "relu": - self.DenseReluDense = FlaxT5DenseReluDense(self.config, dtype=self.dtype) - elif self.config.feed_forward_proj == "gated-gelu": - self.DenseReluDense = FlaxT5DenseGatedGeluDense(self.config, dtype=self.dtype) + if self.config.is_gated_act: + self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype) else: - raise ValueError( - f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`" - ) + self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype) self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype) self.dropout = nn.Dropout(self.config.dropout_rate) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 92b64ea7fb..bdf3026dac 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -276,33 +276,33 @@ except Exception: pass -class T5DenseReluDense(nn.Module): +class T5DenseActDense(nn.Module): def __init__(self, config: T5Config): super().__init__() self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) self.dropout = nn.Dropout(config.dropout_rate) - self.relu_act = ACT2FN["relu"] + self.act = ACT2FN[config.dense_act_fn] def forward(self, hidden_states): hidden_states = self.wi(hidden_states) - hidden_states = self.relu_act(hidden_states) + hidden_states = self.act(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.wo(hidden_states) return hidden_states -class T5DenseGatedGeluDense(nn.Module): +class T5DenseGatedActDense(nn.Module): def __init__(self, config: T5Config): super().__init__() self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) self.dropout = nn.Dropout(config.dropout_rate) - self.gelu_act = ACT2FN["gelu_new"] + self.act = ACT2FN[config.dense_act_fn] def forward(self, hidden_states): - hidden_gelu = self.gelu_act(self.wi_0(hidden_states)) + hidden_gelu = self.act(self.wi_0(hidden_states)) hidden_linear = self.wi_1(hidden_states) hidden_states = hidden_gelu * hidden_linear hidden_states = self.dropout(hidden_states) @@ -313,14 +313,10 @@ class T5DenseGatedGeluDense(nn.Module): class T5LayerFF(nn.Module): def __init__(self, config: T5Config): super().__init__() - if config.feed_forward_proj == "relu": - self.DenseReluDense = T5DenseReluDense(config) - elif config.feed_forward_proj == "gated-gelu": - self.DenseReluDense = T5DenseGatedGeluDense(config) + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense(config) else: - raise ValueError( - f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`" - ) + self.DenseReluDense = T5DenseActDense(config) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -769,7 +765,7 @@ class T5PreTrainedModel(PreTrainedModel): module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) - elif isinstance(module, T5DenseReluDense): + elif isinstance(module, T5DenseActDense): # Mesh TensorFlow FF initialization # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 @@ -779,7 +775,7 @@ class T5PreTrainedModel(PreTrainedModel): module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() - elif isinstance(module, T5DenseGatedGeluDense): + elif isinstance(module, T5DenseGatedActDense): module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: module.wi_0.bias.data.zero_() diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 12ac789c6b..77a65557da 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -93,7 +93,7 @@ class TFT5LayerNorm(tf.keras.layers.Layer): return self.weight * hidden_states -class TFT5DenseReluDense(tf.keras.layers.Layer): +class TFT5DenseActDense(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) wi_initializer = tf.keras.initializers.RandomNormal( @@ -109,7 +109,7 @@ class TFT5DenseReluDense(tf.keras.layers.Layer): config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer ) # Update init weights as in flax self.dropout = tf.keras.layers.Dropout(config.dropout_rate) - self.act = tf.keras.activations.relu + self.act = get_tf_activation(config.dense_act_fn) def call(self, hidden_states, training=False): hidden_states = self.wi(hidden_states) @@ -119,7 +119,7 @@ class TFT5DenseReluDense(tf.keras.layers.Layer): return hidden_states -class TFT5GatedGeluDense(tf.keras.layers.Layer): +class TFT5DenseGatedActDense(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) wi_initializer = tf.keras.initializers.RandomNormal( @@ -138,7 +138,7 @@ class TFT5GatedGeluDense(tf.keras.layers.Layer): config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer ) # Update init weights as in flax self.dropout = tf.keras.layers.Dropout(config.dropout_rate) - self.act = get_tf_activation("gelu_new") + self.act = get_tf_activation(config.dense_act_fn) def call(self, hidden_states, training=False): hidden_gelu = self.act(self.wi_0(hidden_states)) @@ -152,14 +152,11 @@ class TFT5GatedGeluDense(tf.keras.layers.Layer): class TFT5LayerFF(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) - if config.feed_forward_proj == "relu": - self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense") - elif config.feed_forward_proj == "gated-gelu": - self.DenseReluDense = TFT5GatedGeluDense(config, name="DenseReluDense") + if config.is_gated_act: + self.DenseReluDense = TFT5DenseGatedActDense(config, name="DenseReluDense") else: - raise ValueError( - f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`" - ) + self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense") + self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm") self.dropout = tf.keras.layers.Dropout(config.dropout_rate) diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 035e00c05c..4485e65eec 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -539,6 +539,12 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): config.feed_forward_proj = "gated-gelu" self.model_tester.create_and_check_model(config, *config_and_inputs[1:]) + def test_config_and_model_silu_gated(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config = config_and_inputs[0] + config.feed_forward_proj = "gated-silu" + self.model_tester.create_and_check_model(*config_and_inputs) + def test_with_lm_head(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_with_lm_head(*config_and_inputs)