From 4626df5077e1baa7886b034f8dc326fef16f243a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 14 Jun 2023 14:39:02 +0100 Subject: [PATCH] TF: CTRL with native embedding layers (#23456) --- .../models/ctrl/modeling_tf_ctrl.py | 124 ++++++++++-------- tests/models/ctrl/test_modeling_tf_ctrl.py | 1 + 2 files changed, 70 insertions(+), 55 deletions(-) diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py index cd68873903..d800586817 100644 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py @@ -15,10 +15,8 @@ # limitations under the License. """ TF 2.0 CTRL model.""" - from __future__ import annotations -import warnings from typing import Optional, Tuple, Union import numpy as np @@ -30,7 +28,6 @@ from ...modeling_tf_utils import ( TFModelInputType, TFPreTrainedModel, TFSequenceClassificationLoss, - TFSharedEmbeddings, get_initializer, keras_serializable, unpack_inputs, @@ -224,8 +221,11 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size) - self.w = TFSharedEmbeddings( - config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="w" + self.w = tf.keras.layers.Embedding( + input_dim=config.vocab_size, + output_dim=config.n_embd, + embeddings_initializer=get_initializer(config.initializer_range), + name="w", ) self.dropout = tf.keras.layers.Dropout(config.embd_pdrop) @@ -246,9 +246,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): def get_input_embeddings(self): return self.w - def set_input_embeddings(self, value): - self.w.weight = value - self.w.vocab_size = shape_list(value)[0] + def set_input_embeddings(self, new_embeddings): + self.w = new_embeddings def _prune_heads(self, heads_to_prune): """ @@ -308,7 +307,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # this attention mask is more simple than the triangular masking of causal attention # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1])) + attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1] + past_length)) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for @@ -332,15 +331,15 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): if token_type_ids is not None: token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]]) - token_type_embeds = self.w(token_type_ids, mode="embedding") + token_type_embeds = self.w(token_type_ids) token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype)) else: token_type_embeds = tf.constant(0.0) position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]]) if inputs_embeds is None: - check_embeddings_within_bounds(input_ids, self.w.vocab_size) - inputs_embeds = self.w(input_ids, mode="embedding") + check_embeddings_within_bounds(input_ids, self.w.input_dim) + inputs_embeds = self.w(input_ids) seq_len = input_shape[-1] mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0) @@ -565,39 +564,26 @@ class TFCTRLModel(TFCTRLPreTrainedModel): return outputs -class TFCTRLLMHead(tf.keras.layers.Layer): - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - self.config = config - # CTRL has numerical issues in XLA generate - self.supports_xla_generation = False +class TFCTRLBiasLayer(tf.keras.layers.Layer): + """ + Bias as a layer. It is used for serialization purposes: `tf.keras.Model.save_weights` stores on a per-layer basis, + so all weights have to be registered in a layer. + """ - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.input_embeddings = input_embeddings + def __init__(self, shape, initializer, trainable, name, **kwargs): + super().__init__(name=name, **kwargs) + self.shape = shape + self.initializer = initializer + self.trainable = trainable - def build(self, input_shape=None): - self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + def build(self, input_shape): + self.bias = self.add_weight( + name="bias", shape=self.shape, initializer=self.initializer, trainable=self.trainable + ) super().build(input_shape) - def get_output_embeddings(self): - return self.input_embeddings - - def set_output_embeddings(self, value): - self.input_embeddings.weight = value - self.input_embeddings.vocab_size = shape_list(value)[0] - - def get_bias(self): - return {"bias": self.bias} - - def set_bias(self, value): - self.bias = value["bias"] - self.config.vocab_size = shape_list(value["bias"])[0] - - def call(self, hidden_states): - hidden_states = self.input_embeddings(hidden_states, mode="linear") - hidden_states = hidden_states + self.bias - return hidden_states + def call(self, x): + return x + self.bias @add_start_docstrings( @@ -611,24 +597,53 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss): def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.transformer = TFCTRLMainLayer(config, name="transformer") + self.bias_layer = TFCTRLBiasLayer( + name="lm_head", shape=[1, config.vocab_size], initializer="zeros", trainable=True + ) - self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head") - # CTRL has numerical issues in XLA generate - self.supports_xla_generation = False + def get_output_embeddings(self): + return self.get_input_embeddings() - def get_lm_head(self): - return self.lm_head + def set_output_embeddings(self, value): + self.set_input_embeddings(value) - def get_prefix_bias_name(self): - warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) - return self.name + "/" + self.lm_head.name + def get_bias(self): + return {"lm_head.bias": self.bias_layer.bias} - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs): + def set_bias(self, value): + # Replaces the existing layers containing bias for correct (de)serialization. + vocab_size = value["lm_head.bias"].shape[-1] + self.bias_layer = TFCTRLBiasLayer( + name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=True + ) + self.bias_layer.build(None) + self.bias_layer.bias.assign(value["lm_head.bias"]) + + # Copied from transformers.models.gpt2.modeling_tf_gpt2.TFGPT2LMHeadModel.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past_key_values: - input_ids = tf.expand_dims(input_ids[:, -1], -1) + inputs = tf.expand_dims(inputs[:, -1], -1) + if token_type_ids is not None: + token_type_ids = tf.expand_dims(token_type_ids[:, -1], -1) - return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} + position_ids = kwargs.get("position_ids", None) + attention_mask = kwargs.get("attention_mask", None) + + if attention_mask is not None and position_ids is None: + position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) + if past_key_values: + position_ids = tf.expand_dims(position_ids[:, -1], -1) + + return { + "input_ids": inputs, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "token_type_ids": token_type_ids, + } @unpack_inputs @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING) @@ -672,10 +687,9 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss): return_dict=return_dict, training=training, ) - hidden_states = transformer_outputs[0] - - logits = self.lm_head(hidden_states) + logits = tf.matmul(hidden_states, self.transformer.w.weights, transpose_b=True) + logits = self.bias_layer(logits) loss = None if labels is not None: diff --git a/tests/models/ctrl/test_modeling_tf_ctrl.py b/tests/models/ctrl/test_modeling_tf_ctrl.py index 4d94a97828..5c9750d4de 100644 --- a/tests/models/ctrl/test_modeling_tf_ctrl.py +++ b/tests/models/ctrl/test_modeling_tf_ctrl.py @@ -225,6 +225,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase for model_class in self.all_model_classes: model = model_class(config) + model.build() # may be needed for the get_bias() call below assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) if model_class in list_lm_models: