Making TF GPT2 compliant with XLA and AMP (#10230)
* Fix XLA and AMP * Fix AMP and XLA * Apply style * Apply Patrick's comment
This commit is contained in:
@@ -112,6 +112,7 @@ class TFAttention(tf.keras.layers.Layer):
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
attention_mask = tf.cast(attention_mask, dtype=w.dtype)
|
||||
w = w + attention_mask
|
||||
|
||||
w = tf.nn.softmax(w, axis=-1)
|
||||
@@ -224,20 +225,26 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
self.num_hidden_layers = config.n_layer
|
||||
self.vocab_size = config.vocab_size
|
||||
self.n_embd = config.n_embd
|
||||
self.n_positions = config.n_positions
|
||||
self.initializer_range = config.initializer_range
|
||||
|
||||
self.wte = TFSharedEmbeddings(
|
||||
config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
|
||||
)
|
||||
self.wpe = tf.keras.layers.Embedding(
|
||||
config.n_positions,
|
||||
config.n_embd,
|
||||
embeddings_initializer=get_initializer(config.initializer_range),
|
||||
name="wpe",
|
||||
)
|
||||
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
|
||||
self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
|
||||
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
|
||||
|
||||
def build(self, input_shape):
|
||||
with tf.name_scope("wpe"):
|
||||
self.wpe = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.n_positions, self.n_embd],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.wte
|
||||
|
||||
@@ -302,9 +309,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
past_length = shape_list(inputs["past"][0][0])[-2]
|
||||
|
||||
if inputs["position_ids"] is None:
|
||||
inputs["position_ids"] = tf.expand_dims(
|
||||
tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0
|
||||
)
|
||||
inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
|
||||
|
||||
if inputs["attention_mask"] is not None:
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
@@ -322,11 +327,11 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
|
||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
|
||||
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
|
||||
else:
|
||||
inputs["attention_mask"] = None
|
||||
one_cst = tf.constant(1.0)
|
||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
|
||||
inputs["attention_mask"] = tf.multiply(
|
||||
tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
@@ -344,7 +349,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
|
||||
|
||||
position_embeds = self.wpe(inputs["position_ids"])
|
||||
position_embeds = tf.gather(self.wpe, inputs["position_ids"])
|
||||
|
||||
if inputs["token_type_ids"] is not None:
|
||||
inputs["token_type_ids"] = tf.reshape(
|
||||
@@ -352,7 +357,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
)
|
||||
token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
|
||||
else:
|
||||
token_type_embeds = 0
|
||||
token_type_embeds = tf.constant(0.0)
|
||||
|
||||
position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype)
|
||||
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
|
||||
@@ -1024,7 +1029,10 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
|
||||
if inputs["input_ids"] is not None:
|
||||
sequence_lengths = (
|
||||
tf.reduce_sum(
|
||||
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
|
||||
tf.cast(
|
||||
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
|
||||
dtype=inputs["input_ids"].dtype,
|
||||
),
|
||||
-1,
|
||||
keepdims=False,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user