Making TF GPT2 compliant with XLA and AMP (#10230)

* Fix XLA and AMP

* Fix AMP and XLA

* Apply style

* Apply Patrick's comment
This commit is contained in:
Julien Plu
2021-02-18 09:36:01 +01:00
committed by GitHub
parent 5da7c78ed8
commit bdf1669e3f
3 changed files with 25 additions and 138 deletions

View File

@@ -112,6 +112,7 @@ class TFAttention(tf.keras.layers.Layer):
if attention_mask is not None:
# Apply the attention mask
attention_mask = tf.cast(attention_mask, dtype=w.dtype)
w = w + attention_mask
w = tf.nn.softmax(w, axis=-1)
@@ -224,20 +225,26 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
self.num_hidden_layers = config.n_layer
self.vocab_size = config.vocab_size
self.n_embd = config.n_embd
self.n_positions = config.n_positions
self.initializer_range = config.initializer_range
self.wte = TFSharedEmbeddings(
config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
)
self.wpe = tf.keras.layers.Embedding(
config.n_positions,
config.n_embd,
embeddings_initializer=get_initializer(config.initializer_range),
name="wpe",
)
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
def build(self, input_shape):
with tf.name_scope("wpe"):
self.wpe = self.add_weight(
name="embeddings",
shape=[self.n_positions, self.n_embd],
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
def get_input_embeddings(self):
return self.wte
@@ -302,9 +309,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
past_length = shape_list(inputs["past"][0][0])[-2]
if inputs["position_ids"] is None:
inputs["position_ids"] = tf.expand_dims(
tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0
)
inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
if inputs["attention_mask"] is not None:
# We create a 3D attention mask from a 2D tensor mask.
@@ -322,11 +327,11 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
else:
inputs["attention_mask"] = None
one_cst = tf.constant(1.0)
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
inputs["attention_mask"] = tf.multiply(
tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -344,7 +349,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
position_embeds = self.wpe(inputs["position_ids"])
position_embeds = tf.gather(self.wpe, inputs["position_ids"])
if inputs["token_type_ids"] is not None:
inputs["token_type_ids"] = tf.reshape(
@@ -352,7 +357,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
)
token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
else:
token_type_embeds = 0
token_type_embeds = tf.constant(0.0)
position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype)
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
@@ -1024,7 +1029,10 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
if inputs["input_ids"] is not None:
sequence_lengths = (
tf.reduce_sum(
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
tf.cast(
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
dtype=inputs["input_ids"].dtype,
),
-1,
keepdims=False,
)