From abc400b06a8ab26cd438b6e9add3aad082ffc48f Mon Sep 17 00:00:00 2001 From: Thomas Wang <24695242+thomasw21@users.noreply.github.com> Date: Tue, 21 Jun 2022 20:26:36 +0200 Subject: [PATCH] Add final_layer_norm to OPT model (#17785) * Add final_layer_norm to OPT model * Add JAX and TF version * Fix Keras name * Woops * Allow for non breaking change * Apply suggestions from code review * add tests Co-authored-by: Patrick von Platen --- src/transformers/models/opt/configuration_opt.py | 6 ++++++ ...ert_opt_original_pytorch_checkpoint_to_pytorch.py | 4 ++-- src/transformers/models/opt/modeling_flax_opt.py | 11 +++++++++++ src/transformers/models/opt/modeling_opt.py | 12 +++++++++++- src/transformers/models/opt/modeling_tf_opt.py | 11 +++++++++++ tests/models/opt/test_modeling_flax_opt.py | 8 ++++---- tests/models/opt/test_modeling_opt.py | 8 ++++---- tests/models/opt/test_modeling_tf_opt.py | 8 ++++---- 8 files changed, 53 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/opt/configuration_opt.py b/src/transformers/models/opt/configuration_opt.py index eb7c8e0208..a101bb3e86 100644 --- a/src/transformers/models/opt/configuration_opt.py +++ b/src/transformers/models/opt/configuration_opt.py @@ -102,6 +102,7 @@ class OPTConfig(PretrainedConfig): ffn_dim=3072, max_position_embeddings=2048, do_layer_norm_before=True, + _remove_final_layer_norm=False, word_embed_proj_dim=None, dropout=0.1, attention_dropout=0.0, @@ -137,3 +138,8 @@ class OPTConfig(PretrainedConfig): self.layerdrop = layerdrop self.use_cache = use_cache self.do_layer_norm_before = do_layer_norm_before + + # Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + self._remove_final_layer_norm = _remove_final_layer_norm diff --git a/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py index 5992dc7e9a..ec1749daef 100644 --- a/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py @@ -37,8 +37,6 @@ def load_checkpoint(checkpoint_path): # pop unnecessary weights keys_to_delete = [ "decoder.version", - "decoder.layer_norm.weight", - "decoder.layer_norm.bias", "decoder.output_projection.weight", ] for key in keys_to_delete: @@ -48,6 +46,8 @@ def load_checkpoint(checkpoint_path): keys_to_rename = { "decoder.project_in_dim.weight": "decoder.project_in.weight", "decoder.project_out_dim.weight": "decoder.project_out.weight", + "decoder.layer_norm.weight": "decoder.final_layer_norm.weight", + "decoder.layer_norm.bias": "decoder.final_layer_norm.bias", } for old_key, new_key in keys_to_rename.items(): if old_key in sd: diff --git a/src/transformers/models/opt/modeling_flax_opt.py b/src/transformers/models/opt/modeling_flax_opt.py index f84d56b0d8..5762fae14b 100644 --- a/src/transformers/models/opt/modeling_flax_opt.py +++ b/src/transformers/models/opt/modeling_flax_opt.py @@ -452,6 +452,14 @@ class FlaxOPTDecoder(nn.Module): self.project_in = None self.project_out = None + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + else: + self.final_layer_norm = None + self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype) def __call__( @@ -487,6 +495,9 @@ class FlaxOPTDecoder(nn.Module): output_hidden_states=output_hidden_states, ) + if self.final_layer_norm is not None: + hidden_state = self.final_layer_norm(hidden_state) + if self.project_out is not None: hidden_state = self.project_out(hidden_state) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 9e74c98d43..a40dcde0ac 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -492,7 +492,14 @@ class OPTDecoder(OPTPreTrainedModel): else: self.project_in = None - self.layer_norm = None + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + else: + self.final_layer_norm = None + self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False @@ -688,6 +695,9 @@ class OPTDecoder(OPTPreTrainedModel): if output_attentions: all_self_attns += (layer_outputs[1],) + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + if self.project_out is not None: hidden_states = self.project_out(hidden_states) diff --git a/src/transformers/models/opt/modeling_tf_opt.py b/src/transformers/models/opt/modeling_tf_opt.py index 4353020485..89c731b4d5 100644 --- a/src/transformers/models/opt/modeling_tf_opt.py +++ b/src/transformers/models/opt/modeling_tf_opt.py @@ -506,6 +506,14 @@ class TFOPTDecoder(tf.keras.layers.Layer): name="embed_positions", ) + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm") + else: + self.final_layer_norm = None + if config.word_embed_proj_dim != config.hidden_size: self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False) self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False) @@ -681,6 +689,9 @@ class TFOPTDecoder(tf.keras.layers.Layer): if output_attentions: all_self_attns += (layer_self_attn,) + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + if self.project_out is not None: hidden_states = self.project_out(hidden_states) diff --git a/tests/models/opt/test_modeling_flax_opt.py b/tests/models/opt/test_modeling_flax_opt.py index 17dce9eace..208ea0c0d7 100644 --- a/tests/models/opt/test_modeling_flax_opt.py +++ b/tests/models/opt/test_modeling_flax_opt.py @@ -292,10 +292,10 @@ class FlaxOPTGenerationTest(unittest.TestCase): model_id = "facebook/opt-125m" EXPECTED_OUTPUTS = [ - "Today is a beautiful day and I want everyone", - "In the city of Rome Canaver Canaver Canaver Canaver", - "Paris is the capital of France and Parisdylib", - "Computers and mobile phones have taken precedence over", + "Today is a beautiful day and I want to", + "In the city of New York, the city", + "Paris is the capital of France and the capital", + "Computers and mobile phones have taken over the", ] predicted_outputs = [] diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index 8018d05f09..4ebb8e5919 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -344,10 +344,10 @@ class OPTGenerationTest(unittest.TestCase): model_id = "facebook/opt-125m" EXPECTED_OUTPUTS = [ - "Today is a beautiful day and I want everyone", - "In the city of Rome Canaver Canaver Canaver Canaver", - "Paris is the capital of France and Parisdylib", - "Computers and mobile phones have taken precedence over", + "Today is a beautiful day and I want to", + "In the city of New York, the city", + "Paris is the capital of France and the capital", + "Computers and mobile phones have taken over the", ] predicted_outputs = [] diff --git a/tests/models/opt/test_modeling_tf_opt.py b/tests/models/opt/test_modeling_tf_opt.py index d34d4f0fc8..287b3ce319 100644 --- a/tests/models/opt/test_modeling_tf_opt.py +++ b/tests/models/opt/test_modeling_tf_opt.py @@ -330,10 +330,10 @@ class TFOPTGenerationTest(unittest.TestCase): model_id = "facebook/opt-125m" EXPECTED_OUTPUTS = [ - "Today is a beautiful day and I want everyone", - "In the city of Rome Canaver Canaver Canaver Canaver", - "Paris is the capital of France and Parisdylib", - "Computers and mobile phones have taken precedence over", + "Today is a beautiful day and I want to", + "In the city of New York, the city", + "Paris is the capital of France and the capital", + "Computers and mobile phones have taken over the", ] predicted_outputs = []