Add final_layer_norm to OPT model (#17785)
* Add final_layer_norm to OPT model
* Add JAX and TF version
* Fix Keras name
* Whoops
* Allow for non-breaking change
* Apply suggestions from code review
* Add tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -102,6 +102,7 @@ class OPTConfig(PretrainedConfig):
|
|||||||
ffn_dim=3072,
|
ffn_dim=3072,
|
||||||
max_position_embeddings=2048,
|
max_position_embeddings=2048,
|
||||||
do_layer_norm_before=True,
|
do_layer_norm_before=True,
|
||||||
|
_remove_final_layer_norm=False,
|
||||||
word_embed_proj_dim=None,
|
word_embed_proj_dim=None,
|
||||||
dropout=0.1,
|
dropout=0.1,
|
||||||
attention_dropout=0.0,
|
attention_dropout=0.0,
|
||||||
@@ -137,3 +138,8 @@ class OPTConfig(PretrainedConfig):
|
|||||||
self.layerdrop = layerdrop
|
self.layerdrop = layerdrop
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
self.do_layer_norm_before = do_layer_norm_before
|
self.do_layer_norm_before = do_layer_norm_before
|
||||||
|
|
||||||
|
# Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility
|
||||||
|
# with checkpoints that have been fine-tuned before transformers v4.20.1
|
||||||
|
# see https://github.com/facebookresearch/metaseq/pull/164
|
||||||
|
self._remove_final_layer_norm = _remove_final_layer_norm
|
||||||
|
|||||||
@@ -37,8 +37,6 @@ def load_checkpoint(checkpoint_path):
|
|||||||
# pop unnecessary weights
|
# pop unnecessary weights
|
||||||
keys_to_delete = [
|
keys_to_delete = [
|
||||||
"decoder.version",
|
"decoder.version",
|
||||||
"decoder.layer_norm.weight",
|
|
||||||
"decoder.layer_norm.bias",
|
|
||||||
"decoder.output_projection.weight",
|
"decoder.output_projection.weight",
|
||||||
]
|
]
|
||||||
for key in keys_to_delete:
|
for key in keys_to_delete:
|
||||||
@@ -48,6 +46,8 @@ def load_checkpoint(checkpoint_path):
|
|||||||
keys_to_rename = {
|
keys_to_rename = {
|
||||||
"decoder.project_in_dim.weight": "decoder.project_in.weight",
|
"decoder.project_in_dim.weight": "decoder.project_in.weight",
|
||||||
"decoder.project_out_dim.weight": "decoder.project_out.weight",
|
"decoder.project_out_dim.weight": "decoder.project_out.weight",
|
||||||
|
"decoder.layer_norm.weight": "decoder.final_layer_norm.weight",
|
||||||
|
"decoder.layer_norm.bias": "decoder.final_layer_norm.bias",
|
||||||
}
|
}
|
||||||
for old_key, new_key in keys_to_rename.items():
|
for old_key, new_key in keys_to_rename.items():
|
||||||
if old_key in sd:
|
if old_key in sd:
|
||||||
|
|||||||
@@ -452,6 +452,14 @@ class FlaxOPTDecoder(nn.Module):
|
|||||||
self.project_in = None
|
self.project_in = None
|
||||||
self.project_out = None
|
self.project_out = None
|
||||||
|
|
||||||
|
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
|
||||||
|
# with checkpoints that have been fine-tuned before transformers v4.20.1
|
||||||
|
# see https://github.com/facebookresearch/metaseq/pull/164
|
||||||
|
if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm:
|
||||||
|
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
||||||
|
else:
|
||||||
|
self.final_layer_norm = None
|
||||||
|
|
||||||
self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype)
|
self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -487,6 +495,9 @@ class FlaxOPTDecoder(nn.Module):
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.final_layer_norm is not None:
|
||||||
|
hidden_state = self.final_layer_norm(hidden_state)
|
||||||
|
|
||||||
if self.project_out is not None:
|
if self.project_out is not None:
|
||||||
hidden_state = self.project_out(hidden_state)
|
hidden_state = self.project_out(hidden_state)
|
||||||
|
|
||||||
|
|||||||
@@ -492,7 +492,14 @@ class OPTDecoder(OPTPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
self.project_in = None
|
self.project_in = None
|
||||||
|
|
||||||
self.layer_norm = None
|
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
|
||||||
|
# with checkpoints that have been fine-tuned before transformers v4.20.1
|
||||||
|
# see https://github.com/facebookresearch/metaseq/pull/164
|
||||||
|
if config.do_layer_norm_before and not config._remove_final_layer_norm:
|
||||||
|
self.final_layer_norm = nn.LayerNorm(config.hidden_size)
|
||||||
|
else:
|
||||||
|
self.final_layer_norm = None
|
||||||
|
|
||||||
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
@@ -688,6 +695,9 @@ class OPTDecoder(OPTPreTrainedModel):
|
|||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
if self.final_layer_norm is not None:
|
||||||
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
|
|
||||||
if self.project_out is not None:
|
if self.project_out is not None:
|
||||||
hidden_states = self.project_out(hidden_states)
|
hidden_states = self.project_out(hidden_states)
|
||||||
|
|
||||||
|
|||||||
@@ -506,6 +506,14 @@ class TFOPTDecoder(tf.keras.layers.Layer):
|
|||||||
name="embed_positions",
|
name="embed_positions",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
|
||||||
|
# with checkpoints that have been fine-tuned before transformers v4.20.1
|
||||||
|
# see https://github.com/facebookresearch/metaseq/pull/164
|
||||||
|
if config.do_layer_norm_before and not config._remove_final_layer_norm:
|
||||||
|
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||||
|
else:
|
||||||
|
self.final_layer_norm = None
|
||||||
|
|
||||||
if config.word_embed_proj_dim != config.hidden_size:
|
if config.word_embed_proj_dim != config.hidden_size:
|
||||||
self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False)
|
self.project_out = tf.keras.layers.Dense(config.word_embed_proj_dim, name="project_out", use_bias=False)
|
||||||
self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False)
|
self.project_in = tf.keras.layers.Dense(config.hidden_size, name="project_in", use_bias=False)
|
||||||
@@ -681,6 +689,9 @@ class TFOPTDecoder(tf.keras.layers.Layer):
|
|||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
|
|
||||||
|
if self.final_layer_norm is not None:
|
||||||
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
|
|
||||||
if self.project_out is not None:
|
if self.project_out is not None:
|
||||||
hidden_states = self.project_out(hidden_states)
|
hidden_states = self.project_out(hidden_states)
|
||||||
|
|
||||||
|
|||||||
@@ -292,10 +292,10 @@ class FlaxOPTGenerationTest(unittest.TestCase):
|
|||||||
model_id = "facebook/opt-125m"
|
model_id = "facebook/opt-125m"
|
||||||
|
|
||||||
EXPECTED_OUTPUTS = [
|
EXPECTED_OUTPUTS = [
|
||||||
"Today is a beautiful day and I want everyone",
|
"Today is a beautiful day and I want to",
|
||||||
"In the city of Rome Canaver Canaver Canaver Canaver",
|
"In the city of New York, the city",
|
||||||
"Paris is the capital of France and Parisdylib",
|
"Paris is the capital of France and the capital",
|
||||||
"Computers and mobile phones have taken precedence over",
|
"Computers and mobile phones have taken over the",
|
||||||
]
|
]
|
||||||
|
|
||||||
predicted_outputs = []
|
predicted_outputs = []
|
||||||
|
|||||||
@@ -344,10 +344,10 @@ class OPTGenerationTest(unittest.TestCase):
|
|||||||
model_id = "facebook/opt-125m"
|
model_id = "facebook/opt-125m"
|
||||||
|
|
||||||
EXPECTED_OUTPUTS = [
|
EXPECTED_OUTPUTS = [
|
||||||
"Today is a beautiful day and I want everyone",
|
"Today is a beautiful day and I want to",
|
||||||
"In the city of Rome Canaver Canaver Canaver Canaver",
|
"In the city of New York, the city",
|
||||||
"Paris is the capital of France and Parisdylib",
|
"Paris is the capital of France and the capital",
|
||||||
"Computers and mobile phones have taken precedence over",
|
"Computers and mobile phones have taken over the",
|
||||||
]
|
]
|
||||||
|
|
||||||
predicted_outputs = []
|
predicted_outputs = []
|
||||||
|
|||||||
@@ -330,10 +330,10 @@ class TFOPTGenerationTest(unittest.TestCase):
|
|||||||
model_id = "facebook/opt-125m"
|
model_id = "facebook/opt-125m"
|
||||||
|
|
||||||
EXPECTED_OUTPUTS = [
|
EXPECTED_OUTPUTS = [
|
||||||
"Today is a beautiful day and I want everyone",
|
"Today is a beautiful day and I want to",
|
||||||
"In the city of Rome Canaver Canaver Canaver Canaver",
|
"In the city of New York, the city",
|
||||||
"Paris is the capital of France and Parisdylib",
|
"Paris is the capital of France and the capital",
|
||||||
"Computers and mobile phones have taken precedence over",
|
"Computers and mobile phones have taken over the",
|
||||||
]
|
]
|
||||||
|
|
||||||
predicted_outputs = []
|
predicted_outputs = []
|
||||||
|
|||||||
Reference in New Issue
Block a user