Making TF GPT2 compliant with XLA and AMP (#10230)
* Fix XLA and AMP * Fix AMP and XLA * Apply style * Apply Patrick's comment
This commit is contained in:
@@ -1331,119 +1331,6 @@ class TFConv1D(tf.keras.layers.Layer):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class WordEmbeddings(tf.keras.layers.Layer):
|
|
||||||
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
|
|
||||||
def build(self, input_shape):
|
|
||||||
self.word_embeddings = self.add_weight(
|
|
||||||
name="weight",
|
|
||||||
shape=[self.vocab_size, self.hidden_size],
|
|
||||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
|
||||||
)
|
|
||||||
|
|
||||||
super().build(input_shape=input_shape)
|
|
||||||
|
|
||||||
def get_config(self):
|
|
||||||
config = {
|
|
||||||
"vocab_size": self.vocab_size,
|
|
||||||
"hidden_size": self.hidden_size,
|
|
||||||
"initializer_range": self.initializer_range,
|
|
||||||
}
|
|
||||||
base_config = super().get_config()
|
|
||||||
|
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
|
||||||
|
|
||||||
def call(self, input_ids):
|
|
||||||
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
|
|
||||||
embeddings = tf.gather(params=self.word_embeddings, indices=flat_input_ids)
|
|
||||||
embeddings = tf.reshape(
|
|
||||||
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
class TokenTypeEmbeddings(tf.keras.layers.Layer):
|
|
||||||
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.type_vocab_size = type_vocab_size
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
|
|
||||||
def build(self, input_shape):
|
|
||||||
self.token_type_embeddings = self.add_weight(
|
|
||||||
name="embeddings",
|
|
||||||
shape=[self.type_vocab_size, self.hidden_size],
|
|
||||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
|
||||||
)
|
|
||||||
|
|
||||||
super().build(input_shape=input_shape)
|
|
||||||
|
|
||||||
def get_config(self):
|
|
||||||
config = {
|
|
||||||
"type_vocab_size": self.type_vocab_size,
|
|
||||||
"hidden_size": self.hidden_size,
|
|
||||||
"initializer_range": self.initializer_range,
|
|
||||||
}
|
|
||||||
base_config = super().get_config()
|
|
||||||
|
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
|
||||||
|
|
||||||
def call(self, token_type_ids):
|
|
||||||
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
|
|
||||||
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
|
|
||||||
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
|
|
||||||
embeddings = tf.reshape(
|
|
||||||
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddings(tf.keras.layers.Layer):
|
|
||||||
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
|
|
||||||
def build(self, input_shape):
|
|
||||||
self.position_embeddings = self.add_weight(
|
|
||||||
name="embeddings",
|
|
||||||
shape=[self.max_position_embeddings, self.hidden_size],
|
|
||||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
|
||||||
)
|
|
||||||
|
|
||||||
super().build(input_shape)
|
|
||||||
|
|
||||||
def get_config(self):
|
|
||||||
config = {
|
|
||||||
"max_position_embeddings": self.max_position_embeddings,
|
|
||||||
"hidden_size": self.hidden_size,
|
|
||||||
"initializer_range": self.initializer_range,
|
|
||||||
}
|
|
||||||
base_config = super().get_config()
|
|
||||||
|
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
|
||||||
|
|
||||||
def call(self, position_ids):
|
|
||||||
input_shape = shape_list(tensor=position_ids)
|
|
||||||
position_embeddings = self.position_embeddings[: input_shape[1], :]
|
|
||||||
|
|
||||||
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
|
|
||||||
|
|
||||||
|
|
||||||
class TFSharedEmbeddings(tf.keras.layers.Layer):
|
class TFSharedEmbeddings(tf.keras.layers.Layer):
|
||||||
r"""
|
r"""
|
||||||
Construct shared token embeddings.
|
Construct shared token embeddings.
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# Apply the attention mask
|
# Apply the attention mask
|
||||||
|
attention_mask = tf.cast(attention_mask, dtype=w.dtype)
|
||||||
w = w + attention_mask
|
w = w + attention_mask
|
||||||
|
|
||||||
w = tf.nn.softmax(w, axis=-1)
|
w = tf.nn.softmax(w, axis=-1)
|
||||||
@@ -224,20 +225,26 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
self.num_hidden_layers = config.n_layer
|
self.num_hidden_layers = config.n_layer
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.n_embd = config.n_embd
|
self.n_embd = config.n_embd
|
||||||
|
self.n_positions = config.n_positions
|
||||||
|
self.initializer_range = config.initializer_range
|
||||||
|
|
||||||
self.wte = TFSharedEmbeddings(
|
self.wte = TFSharedEmbeddings(
|
||||||
config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
|
config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
|
||||||
)
|
)
|
||||||
self.wpe = tf.keras.layers.Embedding(
|
|
||||||
config.n_positions,
|
|
||||||
config.n_embd,
|
|
||||||
embeddings_initializer=get_initializer(config.initializer_range),
|
|
||||||
name="wpe",
|
|
||||||
)
|
|
||||||
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
|
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
|
||||||
self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
|
self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
|
||||||
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
|
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_f")
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
with tf.name_scope("wpe"):
|
||||||
|
self.wpe = self.add_weight(
|
||||||
|
name="embeddings",
|
||||||
|
shape=[self.n_positions, self.n_embd],
|
||||||
|
initializer=get_initializer(self.initializer_range),
|
||||||
|
)
|
||||||
|
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.wte
|
return self.wte
|
||||||
|
|
||||||
@@ -302,9 +309,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
past_length = shape_list(inputs["past"][0][0])[-2]
|
past_length = shape_list(inputs["past"][0][0])[-2]
|
||||||
|
|
||||||
if inputs["position_ids"] is None:
|
if inputs["position_ids"] is None:
|
||||||
inputs["position_ids"] = tf.expand_dims(
|
inputs["position_ids"] = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
|
||||||
tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs["attention_mask"] is not None:
|
if inputs["attention_mask"] is not None:
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
@@ -322,11 +327,11 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
# positions we want to attend and -10000.0 for masked positions.
|
# positions we want to attend and -10000.0 for masked positions.
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
# effectively the same as removing these entirely.
|
# effectively the same as removing these entirely.
|
||||||
|
one_cst = tf.constant(1.0)
|
||||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
|
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
|
||||||
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
|
inputs["attention_mask"] = tf.multiply(
|
||||||
else:
|
tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
|
||||||
inputs["attention_mask"] = None
|
)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
@@ -344,7 +349,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
if inputs["inputs_embeds"] is None:
|
if inputs["inputs_embeds"] is None:
|
||||||
inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
|
inputs["inputs_embeds"] = self.wte(inputs["input_ids"], mode="embedding")
|
||||||
|
|
||||||
position_embeds = self.wpe(inputs["position_ids"])
|
position_embeds = tf.gather(self.wpe, inputs["position_ids"])
|
||||||
|
|
||||||
if inputs["token_type_ids"] is not None:
|
if inputs["token_type_ids"] is not None:
|
||||||
inputs["token_type_ids"] = tf.reshape(
|
inputs["token_type_ids"] = tf.reshape(
|
||||||
@@ -352,7 +357,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
|
token_type_embeds = self.wte(inputs["token_type_ids"], mode="embedding")
|
||||||
else:
|
else:
|
||||||
token_type_embeds = 0
|
token_type_embeds = tf.constant(0.0)
|
||||||
|
|
||||||
position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype)
|
position_embeds = tf.cast(position_embeds, dtype=inputs["inputs_embeds"].dtype)
|
||||||
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
|
token_type_embeds = tf.cast(token_type_embeds, dtype=inputs["inputs_embeds"].dtype)
|
||||||
@@ -1024,7 +1029,10 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
|
|||||||
if inputs["input_ids"] is not None:
|
if inputs["input_ids"] is not None:
|
||||||
sequence_lengths = (
|
sequence_lengths = (
|
||||||
tf.reduce_sum(
|
tf.reduce_sum(
|
||||||
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
|
tf.cast(
|
||||||
|
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
|
||||||
|
dtype=inputs["input_ids"].dtype,
|
||||||
|
),
|
||||||
-1,
|
-1,
|
||||||
keepdims=False,
|
keepdims=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -389,14 +389,6 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs)
|
self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs)
|
||||||
|
|
||||||
def test_mixed_precision(self):
|
|
||||||
# TODO JP: Make GPT2 float16 compliant
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_xla_mode(self):
|
|
||||||
# TODO JP: Make GPT2 XLA compliant
|
|
||||||
pass
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
Reference in New Issue
Block a user