Add final_layer_norm to OPT model (#17785)

* Add final_layer_norm to OPT model

* Add JAX and TF version

* Fix Keras name

* Whoops

* Allow for non breaking change

* Apply suggestions from code review

* Add tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Thomas Wang
2022-06-21 20:26:36 +02:00
committed by GitHub
parent 52404cbad4
commit abc400b06a
8 changed files with 53 additions and 15 deletions

View File

@@ -102,6 +102,7 @@ class OPTConfig(PretrainedConfig):
ffn_dim=3072, ffn_dim=3072,
max_position_embeddings=2048, max_position_embeddings=2048,
do_layer_norm_before=True, do_layer_norm_before=True,
_remove_final_layer_norm=False,
word_embed_proj_dim=None, word_embed_proj_dim=None,
dropout=0.1, dropout=0.1,
attention_dropout=0.0, attention_dropout=0.0,
@@ -137,3 +138,8 @@ class OPTConfig(PretrainedConfig):
self.layerdrop = layerdrop self.layerdrop = layerdrop
self.use_cache = use_cache self.use_cache = use_cache
self.do_layer_norm_before = do_layer_norm_before self.do_layer_norm_before = do_layer_norm_before
# Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
self._remove_final_layer_norm = _remove_final_layer_norm

View File

@@ -37,8 +37,6 @@ def load_checkpoint(checkpoint_path):
# pop unnecessary weights # pop unnecessary weights
keys_to_delete = [ keys_to_delete = [
"decoder.version", "decoder.version",
"decoder.layer_norm.weight",
"decoder.layer_norm.bias",
"decoder.output_projection.weight", "decoder.output_projection.weight",
] ]
for key in keys_to_delete: for key in keys_to_delete:
@@ -48,6 +46,8 @@ def load_checkpoint(checkpoint_path):
keys_to_rename = { keys_to_rename = {
"decoder.project_in_dim.weight": "decoder.project_in.weight", "decoder.project_in_dim.weight": "decoder.project_in.weight",
"decoder.project_out_dim.weight": "decoder.project_out.weight", "decoder.project_out_dim.weight": "decoder.project_out.weight",
"decoder.layer_norm.weight": "decoder.final_layer_norm.weight",
"decoder.layer_norm.bias": "decoder.final_layer_norm.bias",
} }
for old_key, new_key in keys_to_rename.items(): for old_key, new_key in keys_to_rename.items():
if old_key in sd: if old_key in sd:

View File

@@ -452,6 +452,14 @@ class FlaxOPTDecoder(nn.Module):
self.project_in = None self.project_in = None
self.project_out = None self.project_out = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
else:
self.final_layer_norm = None
self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype) self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype)
def __call__( def __call__(
@@ -487,6 +495,9 @@ class FlaxOPTDecoder(nn.Module):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
) )
if self.final_layer_norm is not None:
hidden_state = self.final_layer_norm(hidden_state)
if self.project_out is not None: if self.project_out is not None:
hidden_state = self.project_out(hidden_state) hidden_state = self.project_out(hidden_state)

View File

@@ -492,7 +492,14 @@ class OPTDecoder(OPTPreTrainedModel):
else: else:
self.project_in = None self.project_in = None
self.layer_norm = None # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm(config.hidden_size)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False self.gradient_checkpointing = False
@@ -688,6 +695,9 @@ class OPTDecoder(OPTPreTrainedModel):
if output_attentions: if output_attentions:
all_self_attns += (layer_outputs[1],) all_self_attns += (layer_outputs[1],)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None: if self.project_out is not None:
hidden_states = self.project_out(hidden_states) hidden_states = self.project_out(hidden_states)

View File

@@ -506,6 +506,14 @@ class TFOPTDecoder(tf.keras.layers.Layer):
name="embed_positions", name="embed_positions",
) )
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
else:
self.final_layer_norm = None
if config.word_embed_proj_dim != config.hidden_size: if config.word_embed_proj_dim != config.hidden_size:
self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False) self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False)
self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False) self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False)
@@ -681,6 +689,9 @@ class TFOPTDecoder(tf.keras.layers.Layer):
if output_attentions: if output_attentions:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None: if self.project_out is not None:
hidden_states = self.project_out(hidden_states) hidden_states = self.project_out(hidden_states)

View File

@@ -292,10 +292,10 @@ class FlaxOPTGenerationTest(unittest.TestCase):
model_id = "facebook/opt-125m" model_id = "facebook/opt-125m"
EXPECTED_OUTPUTS = [ EXPECTED_OUTPUTS = [
"Today is a beautiful day and I want everyone", "Today is a beautiful day and I want to",
"In the city of Rome Canaver Canaver Canaver Canaver", "In the city of New York, the city",
"Paris is the capital of France and Parisdylib", "Paris is the capital of France and the capital",
"Computers and mobile phones have taken precedence over", "Computers and mobile phones have taken over the",
] ]
predicted_outputs = [] predicted_outputs = []

View File

@@ -344,10 +344,10 @@ class OPTGenerationTest(unittest.TestCase):
model_id = "facebook/opt-125m" model_id = "facebook/opt-125m"
EXPECTED_OUTPUTS = [ EXPECTED_OUTPUTS = [
"Today is a beautiful day and I want everyone", "Today is a beautiful day and I want to",
"In the city of Rome Canaver Canaver Canaver Canaver", "In the city of New York, the city",
"Paris is the capital of France and Parisdylib", "Paris is the capital of France and the capital",
"Computers and mobile phones have taken precedence over", "Computers and mobile phones have taken over the",
] ]
predicted_outputs = [] predicted_outputs = []

View File

@@ -330,10 +330,10 @@ class TFOPTGenerationTest(unittest.TestCase):
model_id = "facebook/opt-125m" model_id = "facebook/opt-125m"
EXPECTED_OUTPUTS = [ EXPECTED_OUTPUTS = [
"Today is a beautiful day and I want everyone", "Today is a beautiful day and I want to",
"In the city of Rome Canaver Canaver Canaver Canaver", "In the city of New York, the city",
"Paris is the capital of France and Parisdylib", "Paris is the capital of France and the capital",
"Computers and mobile phones have taken precedence over", "Computers and mobile phones have taken over the",
] ]
predicted_outputs = [] predicted_outputs = []