Making TF TransfoXL model compliant with AMP (#10264)
* Fix AMP * Apply style * Remove unused import
This commit is contained in:
@@ -59,6 +59,7 @@ class TFPositionalEmbedding(tf.keras.layers.Layer):
|
|||||||
self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))
|
self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))
|
||||||
|
|
||||||
def call(self, pos_seq, bsz=None):
|
def call(self, pos_seq, bsz=None):
|
||||||
|
self.inv_freq = tf.cast(self.inv_freq, dtype=pos_seq.dtype)
|
||||||
sinusoid_inp = tf.einsum("i,j->ij", pos_seq, self.inv_freq)
|
sinusoid_inp = tf.einsum("i,j->ij", pos_seq, self.inv_freq)
|
||||||
pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
|
pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
|
||||||
|
|
||||||
@@ -186,6 +187,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||||||
qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
|
qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
|
||||||
|
|
||||||
if mems is not None:
|
if mems is not None:
|
||||||
|
mems = tf.cast(mems, dtype=w.dtype)
|
||||||
cat = tf.concat([mems, w], 0)
|
cat = tf.concat([mems, w], 0)
|
||||||
if self.pre_lnorm:
|
if self.pre_lnorm:
|
||||||
w_heads = self.qkv_net(self.layer_norm(cat))
|
w_heads = self.qkv_net(self.layer_norm(cat))
|
||||||
@@ -227,7 +229,8 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||||||
# compute attention probability
|
# compute attention probability
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
attn_mask_t = attn_mask[:, :, None, None]
|
attn_mask_t = attn_mask[:, :, None, None]
|
||||||
attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
|
attn_mask_t = tf.cast(attn_mask_t, dtype=attn_score.dtype)
|
||||||
|
attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t
|
||||||
|
|
||||||
# [qlen x klen x bsz x n_head]
|
# [qlen x klen x bsz x n_head]
|
||||||
attn_prob = tf.nn.softmax(attn_score, axis=1)
|
attn_prob = tf.nn.softmax(attn_score, axis=1)
|
||||||
@@ -313,6 +316,27 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class TFTransfoEmbeddings(tf.keras.layers.Layer):
|
||||||
|
def __init__(self, vocab_size, emb_size, init_std, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.emb_size = emb_size
|
||||||
|
self.init_std = init_std
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
self.weight = self.add_weight(
|
||||||
|
shape=(self.vocab_size, self.emb_size),
|
||||||
|
initializer=get_initializer(self.init_std),
|
||||||
|
name="embeddings",
|
||||||
|
)
|
||||||
|
|
||||||
|
super().build(input_shape)
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
return tf.gather(self.weight, inputs)
|
||||||
|
|
||||||
|
|
||||||
class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
||||||
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, init_std=0.02, sample_softmax=False, **kwargs):
|
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, init_std=0.02, sample_softmax=False, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -331,6 +355,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
self.emb_layers = []
|
self.emb_layers = []
|
||||||
self.emb_projs = []
|
self.emb_projs = []
|
||||||
|
|
||||||
if div_val == 1:
|
if div_val == 1:
|
||||||
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
|
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
|
||||||
else:
|
else:
|
||||||
@@ -338,10 +363,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
|||||||
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
|
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
|
||||||
d_emb_i = d_embed // (div_val ** i)
|
d_emb_i = d_embed // (div_val ** i)
|
||||||
self.emb_layers.append(
|
self.emb_layers.append(
|
||||||
tf.keras.layers.Embedding(
|
TFTransfoEmbeddings(
|
||||||
r_idx - l_idx,
|
r_idx - l_idx,
|
||||||
d_emb_i,
|
d_emb_i,
|
||||||
embeddings_initializer=get_initializer(init_std),
|
init_std,
|
||||||
name="emb_layers_._{}".format(i),
|
name="emb_layers_._{}".format(i),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -357,6 +382,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
|||||||
name="emb_projs_._{}".format(i),
|
name="emb_projs_._{}".format(i),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(self, inp):
|
def call(self, inp):
|
||||||
@@ -374,8 +400,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
|||||||
emb_i = self.emb_layers[i](inp_i)
|
emb_i = self.emb_layers[i](inp_i)
|
||||||
emb_i = tf.einsum("id,de->ie", emb_i, self.emb_projs[i])
|
emb_i = tf.einsum("id,de->ie", emb_i, self.emb_projs[i])
|
||||||
|
|
||||||
mask_idx = tf.cast(tf.where(mask_i), dtype=tf.int64)
|
mask_idx = tf.where(mask_i)
|
||||||
emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(shape_list(emb_flat), dtype=tf.int64))
|
scatter = tf.scatter_nd(mask_idx, emb_i, shape_list(emb_flat))
|
||||||
|
emb_flat = tf.cast(emb_flat, dtype=scatter.dtype)
|
||||||
|
emb_flat += scatter
|
||||||
|
|
||||||
embed_shape = shape_list(inp) + [self.d_proj]
|
embed_shape = shape_list(inp) + [self.d_proj]
|
||||||
embed = tf.reshape(emb_flat, embed_shape)
|
embed = tf.reshape(emb_flat, embed_shape)
|
||||||
@@ -501,7 +529,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||||||
end_idx = mlen + tf.math.maximum(0, qlen)
|
end_idx = mlen + tf.math.maximum(0, qlen)
|
||||||
beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
|
beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
|
||||||
for i in range(len(hids)):
|
for i in range(len(hids)):
|
||||||
|
mems[i] = tf.cast(mems[i], dtype=hids[i].dtype)
|
||||||
cat = tf.concat([mems[i], hids[i]], axis=0)
|
cat = tf.concat([mems[i], hids[i]], axis=0)
|
||||||
tf.stop_gradient(cat)
|
tf.stop_gradient(cat)
|
||||||
new_mems.append(cat[beg_idx:end_idx])
|
new_mems.append(cat[beg_idx:end_idx])
|
||||||
@@ -1113,7 +1141,6 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
|
|||||||
|
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
logits = self.score(hidden_states)
|
logits = self.score(hidden_states)
|
||||||
logits_shape = shape_list(logits)
|
|
||||||
in_logits = None
|
in_logits = None
|
||||||
if self.config.pad_token_id is None:
|
if self.config.pad_token_id is None:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
@@ -1121,22 +1148,16 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
|
|||||||
if inputs["input_ids"] is not None:
|
if inputs["input_ids"] is not None:
|
||||||
sequence_lengths = (
|
sequence_lengths = (
|
||||||
tf.reduce_sum(
|
tf.reduce_sum(
|
||||||
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
|
tf.cast(
|
||||||
|
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
|
||||||
|
dtype=inputs["input_ids"].dtype,
|
||||||
|
),
|
||||||
-1,
|
-1,
|
||||||
keepdims=False,
|
keepdims=False,
|
||||||
)
|
)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
|
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||||
def get_seq_element(sequence_position, input_batch):
|
|
||||||
return tf.strided_slice(
|
|
||||||
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = tf.map_fn(
|
|
||||||
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
|
|
||||||
)
|
|
||||||
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
|
|
||||||
else:
|
else:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -131,7 +131,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
hidden_sizes = shape_list(hidden)
|
hidden_sizes = shape_list(hidden)
|
||||||
out = []
|
out = []
|
||||||
loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
|
loss = tf.zeros(hidden_sizes[:2])
|
||||||
for i in range(len(self.cutoffs)):
|
for i in range(len(self.cutoffs)):
|
||||||
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
|
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
|
||||||
if target is not None:
|
if target is not None:
|
||||||
@@ -168,7 +168,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
|
|||||||
cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
|
cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
|
||||||
cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
|
cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
|
||||||
if target is not None:
|
if target is not None:
|
||||||
loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(shape_list(loss), dtype=tf.int64))
|
loss += tf.scatter_nd(mask_idx, -cur_logprob, shape_list(loss))
|
||||||
out = tf.concat(out, axis=-1)
|
out = tf.concat(out, axis=-1)
|
||||||
|
|
||||||
if target is not None:
|
if target is not None:
|
||||||
|
|||||||
@@ -205,10 +205,6 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
name = model.get_bias()
|
name = model.get_bias()
|
||||||
assert name is None
|
assert name is None
|
||||||
|
|
||||||
def test_mixed_precision(self):
|
|
||||||
# TODO JP: Make TransfoXL float16 compliant
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_xla_mode(self):
|
def test_xla_mode(self):
|
||||||
# TODO JP: Make TransfoXL XLA compliant
|
# TODO JP: Make TransfoXL XLA compliant
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user