From bdf1669e3f3fe942699c49f78320208a9d66572d Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Thu, 18 Feb 2021 09:36:01 +0100 Subject: [PATCH] Making TF GPT2 compliant with XLA and AMP (#10230) * Fix XLA and AMP * Fix AMP and XLA * Apply style * Apply Patrick's comment --- src/transformers/modeling_tf_utils.py | 113 ------------------ .../models/gpt2/modeling_tf_gpt2.py | 42 ++++--- tests/test_modeling_tf_gpt2.py | 8 -- 3 files changed, 25 insertions(+), 138 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index e0574efbdb..8d4a28aa62 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1331,119 +1331,6 @@ class TFConv1D(tf.keras.layers.Layer): return x -class WordEmbeddings(tf.keras.layers.Layer): - def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs): - super().__init__(**kwargs) - - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.initializer_range = initializer_range - - def build(self, input_shape): - self.word_embeddings = self.add_weight( - name="weight", - shape=[self.vocab_size, self.hidden_size], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - super().build(input_shape=input_shape) - - def get_config(self): - config = { - "vocab_size": self.vocab_size, - "hidden_size": self.hidden_size, - "initializer_range": self.initializer_range, - } - base_config = super().get_config() - - return dict(list(base_config.items()) + list(config.items())) - - def call(self, input_ids): - flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1]) - embeddings = tf.gather(params=self.word_embeddings, indices=flat_input_ids) - embeddings = tf.reshape( - tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0) - ) - - embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size]) - - return embeddings - - -class TokenTypeEmbeddings(tf.keras.layers.Layer): - def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs): - super().__init__(**kwargs) - - self.type_vocab_size = type_vocab_size - self.hidden_size = hidden_size - self.initializer_range = initializer_range - - def build(self, input_shape): - self.token_type_embeddings = self.add_weight( - name="embeddings", - shape=[self.type_vocab_size, self.hidden_size], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - super().build(input_shape=input_shape) - - def get_config(self): - config = { - "type_vocab_size": self.type_vocab_size, - "hidden_size": self.hidden_size, - "initializer_range": self.initializer_range, - } - base_config = super().get_config() - - return dict(list(base_config.items()) + list(config.items())) - - def call(self, token_type_ids): - flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1]) - one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype) - embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings) - embeddings = tf.reshape( - tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0) - ) - - embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size]) - - return embeddings - - -class PositionEmbeddings(tf.keras.layers.Layer): - def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs): - super().__init__(**kwargs) - - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.initializer_range = initializer_range - - def build(self, input_shape): - self.position_embeddings = self.add_weight( - name="embeddings", - shape=[self.max_position_embeddings, self.hidden_size], - initializer=get_initializer(initializer_range=self.initializer_range), - ) - - super().build(input_shape) - - def get_config(self): - config = { - "max_position_embeddings": self.max_position_embeddings, - "hidden_size": self.hidden_size, - "initializer_range": self.initializer_range, - } - base_config = super().get_config() - - return dict(list(base_config.items()) + list(config.items())) - - def call(self, position_ids): - input_shape = shape_list(tensor=position_ids) - position_embeddings = self.position_embeddings[: input_shape[1], :] - - return tf.broadcast_to(input=position_embeddings, shape=input_shape) - - class TFSharedEmbeddings(tf.keras.layers.Layer): r""" Construct shared token embeddings. diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index 4ed3257466..dc233f5d00 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -112,6 +112,7 @@ class TFAttention(tf.keras.layers.Layer): if attention_mask is not None: # Apply the attention mask + attention_mask = tf.cast(attention_mask, dtype=w.dtype) w = w + attention_mask w = tf.nn.softmax(w, axis=-1) @@ -224,20 +225,26 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): self.num_hidden_layers = config.n_layer self.vocab_size = config.vocab_size self.n_embd = config.n_embd + self.n_positions = config.n_positions + self.initializer_range = config.initializer_range self.wte = TFSharedEmbeddings( config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte" ) - self.wpe = tf.keras.layers.Embedding( - config.n_positions, - config.n_embd, - embeddings_initializer=get_initializer(config.initializer_range), - name="wpe", - ) self.drop = tf.keras.layers.Dropout(config.embd_pdrop) self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)] self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f") + def build(self, input_shape): + with tf.name_scope("wpe"): + self.wpe = self.add_weight( + name="embeddings", + shape=[self.n_positions, self.n_embd], + initializer=get_initializer(self.initializer_range), + ) + + super().build(input_shape) + def get_input_embeddings(self): return self.wte @@ -302,9 +309,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): past_length = shape_list(inputs["past"][0][0])[-2] if inputs["position_ids"] is None: - inputs["position_ids"] = tf.expand_dims( - tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0 - ) + inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) if inputs["attention_mask"] is not None: # We create a 3D attention mask from a 2D tensor mask. @@ -322,11 +327,11 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - - inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32) - inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0 - else: - inputs["attention_mask"] = None + one_cst = tf.constant(1.0) + inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype) + inputs["attention_mask"] = tf.multiply( + tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0) + ) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -344,7 +349,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): if inputs["inputs_embeds"] is None: inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding") - position_embeds = self.wpe(inputs["position_ids"]) + position_embeds = tf.gather(self.wpe, inputs["position_ids"]) if inputs["token_type_ids"] is not None: inputs["token_type_ids"] = tf.reshape( @@ -352,7 +357,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ) token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding") else: - token_type_embeds = 0 + token_type_embeds = tf.constant(0.0) position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype) token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype) @@ -1024,7 +1029,10 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific if inputs["input_ids"] is not None: sequence_lengths = ( tf.reduce_sum( - tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32), + tf.cast( + tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), + dtype=inputs["input_ids"].dtype, + ), -1, keepdims=False, ) diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py index 48ca4eef7f..8e13f0fdc1 100644 --- a/tests/test_modeling_tf_gpt2.py +++ b/tests/test_modeling_tf_gpt2.py @@ -389,14 +389,6 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs) - def test_mixed_precision(self): - # TODO JP: Make GPT2 float16 compliant - pass - - def test_xla_mode(self): - # TODO JP: Make GPT2 XLA compliant - pass - @slow def test_model_from_pretrained(self): for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: