Making TF OpenAI GPT model compliant with AMP and XLA (#10261)
* Fix AMP and XLA
* Remove useless var
This commit is contained in:
@@ -81,7 +81,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def causal_attention_mask(nd, ns, dtype):
|
def causal_attention_mask(nd, ns):
|
||||||
"""
|
"""
|
||||||
1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
|
1's in the lower triangle, counting from the lower right corner. Same as tf.matrix_band_part(tf.ones([nd, ns]),
|
||||||
-1, ns-nd), but doesn't produce garbage on TPUs.
|
-1, ns-nd), but doesn't produce garbage on TPUs.
|
||||||
@@ -89,23 +89,24 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
i = tf.range(nd)[:, None]
|
i = tf.range(nd)[:, None]
|
||||||
j = tf.range(ns)
|
j = tf.range(ns)
|
||||||
m = i >= j - ns + nd
|
m = i >= j - ns + nd
|
||||||
return tf.cast(m, dtype)
|
return m
|
||||||
|
|
||||||
def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
|
def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
|
||||||
# q, k, v have shape [batch, heads, sequence, features]
|
# q, k, v have shape [batch, heads, sequence, features]
|
||||||
w = tf.matmul(q, k, transpose_b=True)
|
w = tf.matmul(q, k, transpose_b=True)
|
||||||
if self.scale:
|
if self.scale:
|
||||||
dk = tf.cast(shape_list(k)[-1], tf.float32) # scale attention_scores
|
dk = tf.cast(shape_list(k)[-1], dtype=w.dtype) # scale attention_scores
|
||||||
w = w / tf.math.sqrt(dk)
|
w = w / tf.math.sqrt(dk)
|
||||||
|
|
||||||
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
|
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
|
||||||
_, _, nd, ns = shape_list(w)
|
_, _, nd, ns = shape_list(w)
|
||||||
b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
|
b = tf.cast(self.causal_attention_mask(nd, ns), dtype=w.dtype)
|
||||||
b = tf.reshape(b, [1, 1, nd, ns])
|
b = tf.reshape(b, [1, 1, nd, ns])
|
||||||
w = w * b - 1e4 * (1 - b)
|
w = w * b - 1e4 * (1 - b)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# Apply the attention mask
|
# Apply the attention mask
|
||||||
|
attention_mask = tf.cast(attention_mask, dtype=w.dtype)
|
||||||
w = w + attention_mask
|
w = w + attention_mask
|
||||||
|
|
||||||
w = tf.nn.softmax(w, axis=-1)
|
w = tf.nn.softmax(w, axis=-1)
|
||||||
@@ -201,19 +202,25 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
|||||||
self.num_hidden_layers = config.n_layer
|
self.num_hidden_layers = config.n_layer
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.n_embd = config.n_embd
|
self.n_embd = config.n_embd
|
||||||
|
self.n_positions = config.n_positions
|
||||||
|
self.initializer_range = config.initializer_range
|
||||||
|
|
||||||
self.tokens_embed = TFSharedEmbeddings(
|
self.tokens_embed = TFSharedEmbeddings(
|
||||||
config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed"
|
config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="tokens_embed"
|
||||||
)
|
)
|
||||||
self.positions_embed = tf.keras.layers.Embedding(
|
|
||||||
config.n_positions,
|
|
||||||
config.n_embd,
|
|
||||||
embeddings_initializer=get_initializer(config.initializer_range),
|
|
||||||
name="positions_embed",
|
|
||||||
)
|
|
||||||
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
|
self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
|
||||||
self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
|
self.h = [TFBlock(config.n_ctx, config, scale=True, name="h_._{}".format(i)) for i in range(config.n_layer)]
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
with tf.name_scope("positions_embed"):
|
||||||
|
self.positions_embed = self.add_weight(
|
||||||
|
name="embeddings",
|
||||||
|
shape=[self.n_positions, self.n_embd],
|
||||||
|
initializer=get_initializer(self.initializer_range),
|
||||||
|
)
|
||||||
|
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.tokens_embed
|
return self.tokens_embed
|
||||||
|
|
||||||
@@ -268,7 +275,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
|||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
if inputs["position_ids"] is None:
|
if inputs["position_ids"] is None:
|
||||||
inputs["position_ids"] = tf.expand_dims(tf.range(input_shape[-1], dtype=tf.int32), axis=0)
|
inputs["position_ids"] = tf.expand_dims(tf.range(input_shape[-1]), axis=0)
|
||||||
|
|
||||||
if inputs["attention_mask"] is not None:
|
if inputs["attention_mask"] is not None:
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
@@ -284,8 +291,11 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
|||||||
# Since we are adding it to the raw scores before the softmax, this is
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
# effectively the same as removing these entirely.
|
# effectively the same as removing these entirely.
|
||||||
|
|
||||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
|
one_cst = tf.constant(1.0)
|
||||||
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
|
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
|
||||||
|
inputs["attention_mask"] = tf.multiply(
|
||||||
|
tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
inputs["attention_mask"] = None
|
inputs["attention_mask"] = None
|
||||||
|
|
||||||
@@ -304,7 +314,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
if inputs["inputs_embeds"] is None:
|
if inputs["inputs_embeds"] is None:
|
||||||
inputs["inputs_embeds"] = self.tokens_embed(inputs["input_ids"], mode="embedding")
|
inputs["inputs_embeds"] = self.tokens_embed(inputs["input_ids"], mode="embedding")
|
||||||
position_embeds = self.positions_embed(inputs["position_ids"])
|
position_embeds = tf.gather(self.positions_embed, inputs["position_ids"])
|
||||||
if inputs["token_type_ids"] is not None:
|
if inputs["token_type_ids"] is not None:
|
||||||
inputs["token_type_ids"] = tf.reshape(
|
inputs["token_type_ids"] = tf.reshape(
|
||||||
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
|
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
|
||||||
@@ -903,7 +913,6 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
|
|||||||
|
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
logits = self.score(hidden_states)
|
logits = self.score(hidden_states)
|
||||||
logits_shape = shape_list(logits)
|
|
||||||
in_logits = None
|
in_logits = None
|
||||||
if self.config.pad_token_id is None:
|
if self.config.pad_token_id is None:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
@@ -911,22 +920,16 @@ class TFOpenAIGPTForSequenceClassification(TFOpenAIGPTPreTrainedModel, TFSequenc
|
|||||||
if inputs["input_ids"] is not None:
|
if inputs["input_ids"] is not None:
|
||||||
sequence_lengths = (
|
sequence_lengths = (
|
||||||
tf.reduce_sum(
|
tf.reduce_sum(
|
||||||
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
|
tf.cast(
|
||||||
|
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
|
||||||
|
dtype=inputs["input_ids"].dtype,
|
||||||
|
),
|
||||||
-1,
|
-1,
|
||||||
keepdims=False,
|
keepdims=False,
|
||||||
)
|
)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
|
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||||
def get_seq_element(sequence_position, input_batch):
|
|
||||||
return tf.strided_slice(
|
|
||||||
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = tf.map_fn(
|
|
||||||
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
|
|
||||||
)
|
|
||||||
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
|
|
||||||
else:
|
else:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -246,14 +246,6 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_openai_gpt_for_sequence_classification(*config_and_inputs)
|
self.model_tester.create_and_check_openai_gpt_for_sequence_classification(*config_and_inputs)
|
||||||
|
|
||||||
def test_mixed_precision(self):
|
|
||||||
# TODO JP: Make OpenAIGPT float16 compliant
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_xla_mode(self):
|
|
||||||
# TODO JP: Make OpenAIGPT XLA compliant
|
|
||||||
pass
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
Reference in New Issue
Block a user