From db136341836d6d9a4d599e2935253e0ae04cc1f1 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 18 May 2023 14:46:40 +0100 Subject: [PATCH] TF: GPT2 with native embedding layers (#23436) --- docs/source/en/internal/modeling_utils.mdx | 3 -- src/transformers/modeling_tf_utils.py | 4 ++ .../models/gpt2/modeling_tf_gpt2.py | 39 +++++++++---------- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/docs/source/en/internal/modeling_utils.mdx b/docs/source/en/internal/modeling_utils.mdx index 914b8ca367..578740df02 100644 --- a/docs/source/en/internal/modeling_utils.mdx +++ b/docs/source/en/internal/modeling_utils.mdx @@ -54,9 +54,6 @@ Most of those are only useful if you are studying the code of the models in the [[autodoc]] modeling_tf_utils.TFConv1D -[[autodoc]] modeling_tf_utils.TFSharedEmbeddings - - call - [[autodoc]] modeling_tf_utils.TFSequenceSummary ## TensorFlow loss functions diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 5618472744..630290d921 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -3132,6 +3132,10 @@ class TFSharedEmbeddings(tf.keras.layers.Layer): self.vocab_size = vocab_size self.hidden_size = hidden_size self.initializer_range = hidden_size**-0.5 if initializer_range is None else initializer_range + warnings.warn( + "`TFSharedEmbeddings` is scheduled for deletion in v4.32, use `tf.keras.layers.Embedding` instead.", + DeprecationWarning, + ) def build(self, input_shape): """ diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index d0c731878d..b7cb1b6df2 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -34,7 +34,6 @@ from ...modeling_tf_utils import ( TFPreTrainedModel, TFSequenceClassificationLoss, TFSequenceSummary, - TFSharedEmbeddings, get_initializer, keras_serializable, unpack_inputs, @@ -315,29 +314,27 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): self.n_positions = config.n_positions self.initializer_range = config.initializer_range - self.wte = TFSharedEmbeddings( - config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte" + self.wte = tf.keras.layers.Embedding( + input_dim=config.vocab_size, + output_dim=config.hidden_size, + embeddings_initializer=get_initializer(config.initializer_range), + name="wte", + ) + self.wpe = tf.keras.layers.Embedding( + input_dim=config.n_positions, + output_dim=config.n_embd, + embeddings_initializer=get_initializer(config.initializer_range), + name="wpe", ) self.drop = tf.keras.layers.Dropout(config.embd_pdrop) self.h = [TFBlock(config, scale=True, name=f"h_._{i}") for i in range(config.n_layer)] self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f") - def build(self, input_shape): - with tf.name_scope("wpe"): - self.wpe = self.add_weight( - name="embeddings", - shape=[self.n_positions, self.n_embd], - initializer=get_initializer(self.initializer_range), - ) - - super().build(input_shape) - def get_input_embeddings(self): return self.wte - def set_input_embeddings(self, value): - self.wte.weight = value - self.wte.vocab_size = shape_list(value)[0] + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings def _prune_heads(self, heads_to_prune): """ @@ -438,13 +435,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): if inputs_embeds is None: check_embeddings_within_bounds(input_ids, self.config.vocab_size) - inputs_embeds = self.wte(input_ids, mode="embedding") + inputs_embeds = self.wte(input_ids) - position_embeds = tf.gather(self.wpe, position_ids) + position_embeds = self.wpe(position_ids) if token_type_ids is not None: token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) - token_type_embeds = self.wte(token_type_ids, mode="embedding") + token_type_embeds = self.wte(token_type_ids) else: token_type_embeds = tf.constant(0.0) @@ -904,7 +901,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss): training=training, ) hidden_states = transformer_outputs[0] - logits = self.transformer.wte(hidden_states, mode="linear") + logits = tf.matmul(hidden_states, self.transformer.wte.weights, transpose_b=True) loss = None if labels is not None: @@ -1048,7 +1045,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,) else: all_hidden_states = None - lm_logits = self.transformer.wte(hidden_states, mode="linear") + lm_logits = tf.matmul(hidden_states, self.transformer.wte.weights, transpose_b=True) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training) mc_logits = tf.squeeze(mc_logits, axis=-1)