Make TF CTRL compliant with XLA and AMP (#10209)
* Fix XLA and AMP * Apply style * Remove useless cast
This commit is contained in:
@@ -48,7 +48,7 @@ TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|||||||
|
|
||||||
|
|
||||||
def angle_defn(pos, i, d_model_size):
|
def angle_defn(pos, i, d_model_size):
|
||||||
angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model_size))
|
angle_rates = 1 / np.power(10000, (2 * (i // 2)) / d_model_size)
|
||||||
return pos * angle_rates
|
return pos * angle_rates
|
||||||
|
|
||||||
|
|
||||||
@@ -58,9 +58,8 @@ def positional_encoding(position, d_model_size):
|
|||||||
|
|
||||||
sines = np.sin(angle_rads[:, 0::2])
|
sines = np.sin(angle_rads[:, 0::2])
|
||||||
cosines = np.cos(angle_rads[:, 1::2])
|
cosines = np.cos(angle_rads[:, 1::2])
|
||||||
|
pos_encoding = tf.convert_to_tensor(np.concatenate([sines, cosines], axis=-1))
|
||||||
|
|
||||||
# pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1)[np.newaxis, ...], dtype=tf.float32)
|
|
||||||
pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1), dtype=tf.float32)
|
|
||||||
return pos_encoding
|
return pos_encoding
|
||||||
|
|
||||||
|
|
||||||
@@ -68,14 +67,15 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
|
|||||||
# calculate attention
|
# calculate attention
|
||||||
matmul_qk = tf.matmul(q, k, transpose_b=True)
|
matmul_qk = tf.matmul(q, k, transpose_b=True)
|
||||||
|
|
||||||
dk = tf.cast(shape_list(k)[-1], tf.float32)
|
dk = tf.cast(shape_list(k)[-1], dtype=matmul_qk.dtype)
|
||||||
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
|
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
scaled_attention_logits += mask * -1e4
|
scaled_attention_logits += tf.cast(mask * -1e4, dtype=scaled_attention_logits.dtype)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# Apply the attention mask
|
# Apply the attention mask
|
||||||
|
attention_mask = tf.cast(attention_mask, dtype=scaled_attention_logits.dtype)
|
||||||
scaled_attention_logits = scaled_attention_logits + attention_mask
|
scaled_attention_logits = scaled_attention_logits + attention_mask
|
||||||
|
|
||||||
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
|
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
|
||||||
@@ -332,10 +332,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
# Since we are adding it to the raw scores before the softmax, this is
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
# effectively the same as removing these entirely.
|
# effectively the same as removing these entirely.
|
||||||
|
|
||||||
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], tf.float32)
|
one_cst = tf.constant(1.0)
|
||||||
inputs["attention_mask"] = (1.0 - inputs["attention_mask"]) * -10000.0
|
ten_thousand_cst = tf.constant(-10000.0)
|
||||||
else:
|
inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
|
||||||
inputs["attention_mask"] = None
|
inputs["attention_mask"] = tf.multiply(tf.subtract(one_cst, inputs["attention_mask"]), ten_thousand_cst)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
@@ -351,9 +351,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
|
inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
|
||||||
)
|
)
|
||||||
token_type_embeds = self.w(inputs["token_type_ids"], mode="embedding")
|
token_type_embeds = self.w(inputs["token_type_ids"], mode="embedding")
|
||||||
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
|
token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, dtype=token_type_embeds.dtype))
|
||||||
else:
|
else:
|
||||||
token_type_embeds = 0
|
token_type_embeds = tf.constant(0.0)
|
||||||
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
|
inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
|
||||||
|
|
||||||
if inputs["inputs_embeds"] is None:
|
if inputs["inputs_embeds"] is None:
|
||||||
@@ -361,10 +361,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
seq_len = input_shape[-1]
|
seq_len = input_shape[-1]
|
||||||
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
|
mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
|
||||||
|
|
||||||
inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
|
inputs["inputs_embeds"] *= tf.math.sqrt(tf.cast(self.d_model_size, inputs["inputs_embeds"].dtype))
|
||||||
|
|
||||||
pos_embeds = tf.gather(self.pos_encoding, inputs["position_ids"])
|
pos_embeds = tf.gather(self.pos_encoding, inputs["position_ids"])
|
||||||
|
pos_embeds = tf.cast(pos_embeds, dtype=token_type_embeds.dtype)
|
||||||
hidden_states = inputs["inputs_embeds"] + pos_embeds + token_type_embeds
|
hidden_states = inputs["inputs_embeds"] + pos_embeds + token_type_embeds
|
||||||
|
|
||||||
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
hidden_states = self.dropout(hidden_states, training=inputs["training"])
|
||||||
@@ -857,7 +857,6 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
|
|||||||
|
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
logits = self.classifier(hidden_states)
|
logits = self.classifier(hidden_states)
|
||||||
logits_shape = shape_list(logits)
|
|
||||||
in_logits = None
|
in_logits = None
|
||||||
if self.config.pad_token_id is None:
|
if self.config.pad_token_id is None:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
@@ -865,22 +864,16 @@ class TFCTRLForSequenceClassification(TFCTRLPreTrainedModel, TFSequenceClassific
|
|||||||
if inputs["input_ids"] is not None:
|
if inputs["input_ids"] is not None:
|
||||||
sequence_lengths = (
|
sequence_lengths = (
|
||||||
tf.reduce_sum(
|
tf.reduce_sum(
|
||||||
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
|
tf.cast(
|
||||||
|
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
|
||||||
|
dtype=inputs["input_ids"].dtype,
|
||||||
|
),
|
||||||
-1,
|
-1,
|
||||||
keepdims=False,
|
keepdims=False,
|
||||||
)
|
)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
|
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||||
def get_seq_element(sequence_position, input_batch):
|
|
||||||
return tf.strided_slice(
|
|
||||||
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = tf.map_fn(
|
|
||||||
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
|
|
||||||
)
|
|
||||||
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
|
|
||||||
else:
|
else:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -222,14 +222,6 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
name = model.get_bias()
|
name = model.get_bias()
|
||||||
assert name is None
|
assert name is None
|
||||||
|
|
||||||
def test_mixed_precision(self):
|
|
||||||
# TODO JP: Make CTRL float16 compliant
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_xla_mode(self):
|
|
||||||
# TODO JP: Make CTRL XLA compliant
|
|
||||||
pass
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
Reference in New Issue
Block a user