Making TF XLM-like models XLA and AMP compliant (#10211)
* Fix Flaubert and XLM * Remove useless cast * Tiny fix * Tiny fix
This commit is contained in:
@@ -171,7 +171,7 @@ FLAUBERT_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
|
||||
def get_masks(slen, lengths, causal, padding_mask=None):
|
||||
"""
|
||||
Generate hidden states mask, and optionally an attention mask.
|
||||
"""
|
||||
@@ -193,12 +193,10 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
|
||||
|
||||
# sanity check
|
||||
# assert shape_list(mask) == [bs, slen]
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
||||
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
|
||||
|
||||
mask = tf.cast(mask, dtype=dtype)
|
||||
attn_mask = tf.cast(attn_mask, dtype=dtype)
|
||||
|
||||
return mask, attn_mask
|
||||
|
||||
|
||||
@@ -339,8 +337,7 @@ class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
|
||||
klen = shape_list(kv)[1]
|
||||
|
||||
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
|
||||
dim_per_head = tf.math.divide(self.dim, self.n_heads)
|
||||
dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
|
||||
dim_per_head = self.dim // self.n_heads
|
||||
mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)
|
||||
|
||||
def shape(x):
|
||||
@@ -372,8 +369,8 @@ class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
|
||||
|
||||
cache[self.layer_id] = (k, v)
|
||||
|
||||
q = tf.cast(q, dtype=tf.float32)
|
||||
q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32))) # (bs, n_heads, qlen, dim_per_head)
|
||||
f_dim_per_head = tf.cast(dim_per_head, dtype=q.dtype)
|
||||
q = tf.multiply(q, tf.math.rsqrt(f_dim_per_head)) # (bs, n_heads, qlen, dim_per_head)
|
||||
k = tf.cast(k, dtype=q.dtype)
|
||||
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
|
||||
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
|
||||
@@ -438,22 +435,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.return_dict = config.use_return_dict
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.embed_init_std = config.embed_init_std
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
self.position_embeddings = tf.keras.layers.Embedding(
|
||||
config.max_position_embeddings,
|
||||
self.dim,
|
||||
embeddings_initializer=get_initializer(config.embed_init_std),
|
||||
name="position_embeddings",
|
||||
)
|
||||
|
||||
if config.n_langs > 1 and config.use_lang_emb:
|
||||
self.lang_embeddings = tf.keras.layers.Embedding(
|
||||
self.n_langs,
|
||||
self.dim,
|
||||
embeddings_initializer=get_initializer(config.embed_init_std),
|
||||
name="lang_embeddings",
|
||||
)
|
||||
|
||||
self.embeddings = TFSharedEmbeddings(
|
||||
self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
|
||||
)
|
||||
@@ -482,6 +466,24 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2_._{}".format(i))
|
||||
)
|
||||
|
||||
def build(self, input_shape):
|
||||
with tf.name_scope("position_embeddings"):
|
||||
self.position_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.max_position_embeddings, self.dim],
|
||||
initializer=get_initializer(self.embed_init_std),
|
||||
)
|
||||
|
||||
if self.n_langs > 1 and self.use_lang_emb:
|
||||
with tf.name_scope("lang_embeddings"):
|
||||
self.lang_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.n_langs, self.dim],
|
||||
initializer=get_initializer(self.embed_init_std),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
@@ -538,14 +540,15 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
if inputs["lengths"] is None:
|
||||
if inputs["input_ids"] is not None:
|
||||
inputs["lengths"] = tf.reduce_sum(
|
||||
tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
|
||||
tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=inputs["input_ids"].dtype), axis=1
|
||||
)
|
||||
else:
|
||||
inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
|
||||
inputs["lengths"] = tf.convert_to_tensor([slen] * bs)
|
||||
# mask = input_ids != self.pad_index
|
||||
|
||||
# check inputs
|
||||
# assert shape_list(lengths)[0] == bs
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["lengths"])[0], bs
|
||||
), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
|
||||
@@ -564,7 +567,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
# position_ids
|
||||
if inputs["position_ids"] is None:
|
||||
inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
|
||||
else:
|
||||
inputs["position_ids"] = tf.tile(inputs["position_ids"], (bs, 1))
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["position_ids"]), [bs, slen]
|
||||
@@ -572,7 +577,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
# position_ids = position_ids.transpose(0, 1)
|
||||
|
||||
# langs
|
||||
if inputs["langs"] is not None:
|
||||
if inputs["langs"] is not None and tf.executing_eagerly():
|
||||
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["langs"]), [bs, slen]
|
||||
@@ -603,15 +608,16 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer):
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])
|
||||
|
||||
tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])
|
||||
tensor = inputs["inputs_embeds"] + tf.gather(self.position_embeddings, inputs["position_ids"])
|
||||
|
||||
if inputs["langs"] is not None and self.use_lang_emb:
|
||||
tensor = tensor + self.lang_embeddings(inputs["langs"])
|
||||
tensor = tensor + tf.gather(self.lang_embeddings, inputs["langs"])
|
||||
if inputs["token_type_ids"] is not None:
|
||||
tensor = tensor + self.embeddings(inputs["token_type_ids"])
|
||||
|
||||
tensor = self.layer_norm_emb(tensor)
|
||||
tensor = self.dropout(tensor, training=inputs["training"])
|
||||
mask = tf.cast(mask, dtype=tensor.dtype)
|
||||
tensor = tensor * tf.expand_dims(mask, axis=-1)
|
||||
|
||||
# hidden_states and attentions cannot be None in graph mode.
|
||||
@@ -804,7 +810,7 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel):
|
||||
lang_id = self.config.lang_id
|
||||
|
||||
effective_batch_size = inputs.shape[0]
|
||||
mask_token = tf.ones((effective_batch_size, 1), dtype=tf.int32) * mask_token_id
|
||||
mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id
|
||||
inputs = tf.concat([inputs, mask_token], axis=1)
|
||||
|
||||
if lang_id is not None:
|
||||
|
||||
@@ -82,7 +82,7 @@ def create_sinusoidal_embeddings(n_pos, dim, out):
|
||||
out[:, 1::2] = tf.constant(np.cos(position_enc[:, 1::2]))
|
||||
|
||||
|
||||
def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
|
||||
def get_masks(slen, lengths, causal, padding_mask=None):
|
||||
"""
|
||||
Generate hidden states mask, and optionally an attention mask.
|
||||
"""
|
||||
@@ -104,12 +104,10 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
|
||||
|
||||
# sanity check
|
||||
# assert shape_list(mask) == [bs, slen]
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
||||
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
|
||||
|
||||
mask = tf.cast(mask, dtype=dtype)
|
||||
attn_mask = tf.cast(attn_mask, dtype=dtype)
|
||||
|
||||
return mask, attn_mask
|
||||
|
||||
|
||||
@@ -148,8 +146,7 @@ class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
|
||||
klen = shape_list(kv)[1]
|
||||
|
||||
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
|
||||
dim_per_head = tf.math.divide(self.dim, self.n_heads)
|
||||
dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
|
||||
dim_per_head = self.dim // self.n_heads
|
||||
mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)
|
||||
|
||||
def shape(x):
|
||||
@@ -181,8 +178,8 @@ class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
|
||||
|
||||
cache[self.layer_id] = (k, v)
|
||||
|
||||
q = tf.cast(q, dtype=tf.float32)
|
||||
q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32))) # (bs, n_heads, qlen, dim_per_head)
|
||||
f_dim_per_head = tf.cast(dim_per_head, dtype=q.dtype)
|
||||
q = tf.multiply(q, tf.math.rsqrt(f_dim_per_head)) # (bs, n_heads, qlen, dim_per_head)
|
||||
k = tf.cast(k, dtype=q.dtype)
|
||||
scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
|
||||
mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen)
|
||||
@@ -263,30 +260,18 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
self.hidden_dim = self.dim * 4 # 2048 by default
|
||||
self.n_heads = config.n_heads # 8 by default
|
||||
self.n_layers = config.n_layers
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.embed_init_std = config.embed_init_std
|
||||
assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads"
|
||||
|
||||
# embeddings
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
self.attention_dropout = tf.keras.layers.Dropout(config.attention_dropout)
|
||||
self.position_embeddings = tf.keras.layers.Embedding(
|
||||
config.max_position_embeddings,
|
||||
self.dim,
|
||||
embeddings_initializer=get_initializer(config.embed_init_std),
|
||||
name="position_embeddings",
|
||||
)
|
||||
|
||||
if config.sinusoidal_embeddings:
|
||||
raise NotImplementedError
|
||||
# create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
|
||||
|
||||
if config.n_langs > 1 and config.use_lang_emb:
|
||||
self.lang_embeddings = tf.keras.layers.Embedding(
|
||||
self.n_langs,
|
||||
self.dim,
|
||||
embeddings_initializer=get_initializer(config.embed_init_std),
|
||||
name="lang_embeddings",
|
||||
)
|
||||
|
||||
self.embeddings = TFSharedEmbeddings(
|
||||
self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings"
|
||||
) # padding_idx=self.pad_index)
|
||||
@@ -326,6 +311,24 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
if self.attentions[int(layer)].n_heads == config.n_heads:
|
||||
self.prune_heads({int(layer): list(map(int, heads))})
|
||||
|
||||
def build(self, input_shape):
|
||||
with tf.name_scope("position_embeddings"):
|
||||
self.position_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.max_position_embeddings, self.dim],
|
||||
initializer=get_initializer(self.embed_init_std),
|
||||
)
|
||||
|
||||
if self.n_langs > 1 and self.use_lang_emb:
|
||||
with tf.name_scope("lang_embeddings"):
|
||||
self.lang_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.n_langs, self.dim],
|
||||
initializer=get_initializer(self.embed_init_std),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
@@ -389,14 +392,15 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
if inputs["lengths"] is None:
|
||||
if inputs["input_ids"] is not None:
|
||||
inputs["lengths"] = tf.reduce_sum(
|
||||
tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1
|
||||
tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=inputs["input_ids"].dtype), axis=1
|
||||
)
|
||||
else:
|
||||
inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32)
|
||||
inputs["lengths"] = tf.convert_to_tensor([slen] * bs)
|
||||
# mask = input_ids != self.pad_index
|
||||
|
||||
# check inputs
|
||||
# assert shape_list(lengths)[0] == bs
|
||||
if tf.executing_eagerly():
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["lengths"])[0], bs
|
||||
), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched"
|
||||
@@ -415,7 +419,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
# position_ids
|
||||
if inputs["position_ids"] is None:
|
||||
inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0)
|
||||
else:
|
||||
inputs["position_ids"] = tf.tile(inputs["position_ids"], (bs, 1))
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["position_ids"]), [bs, slen]
|
||||
@@ -423,7 +429,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
# position_ids = position_ids.transpose(0, 1)
|
||||
|
||||
# langs
|
||||
if inputs["langs"] is not None:
|
||||
if inputs["langs"] is not None and tf.executing_eagerly():
|
||||
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(inputs["langs"]), [bs, slen]
|
||||
@@ -454,15 +460,16 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
if inputs["inputs_embeds"] is None:
|
||||
inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"])
|
||||
|
||||
tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"])
|
||||
tensor = inputs["inputs_embeds"] + tf.gather(self.position_embeddings, inputs["position_ids"])
|
||||
|
||||
if inputs["langs"] is not None and self.use_lang_emb and self.n_langs > 1:
|
||||
tensor = tensor + self.lang_embeddings(inputs["langs"])
|
||||
tensor = tensor + tf.gather(self.lang_embeddings, inputs["langs"])
|
||||
if inputs["token_type_ids"] is not None:
|
||||
tensor = tensor + self.embeddings(inputs["token_type_ids"])
|
||||
|
||||
tensor = self.layer_norm_emb(tensor)
|
||||
tensor = self.dropout(tensor, training=inputs["training"])
|
||||
mask = tf.cast(mask, dtype=tensor.dtype)
|
||||
tensor = tensor * tf.expand_dims(mask, axis=-1)
|
||||
|
||||
# transformer layers
|
||||
@@ -837,7 +844,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
|
||||
lang_id = self.config.lang_id
|
||||
|
||||
effective_batch_size = inputs.shape[0]
|
||||
mask_token = tf.ones((effective_batch_size, 1), dtype=tf.int32) * mask_token_id
|
||||
mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id
|
||||
inputs = tf.concat([inputs, mask_token], axis=1)
|
||||
|
||||
if lang_id is not None:
|
||||
|
||||
@@ -331,14 +331,6 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
model = TFFlaubertModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_mixed_precision(self):
|
||||
# TODO JP: Make Flaubert float16 compliant
|
||||
pass
|
||||
|
||||
def test_xla_mode(self):
|
||||
# TODO JP: Make Flaubert XLA compliant
|
||||
pass
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_sentencepiece
|
||||
|
||||
@@ -327,14 +327,6 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
def test_mixed_precision(self):
|
||||
# TODO JP: Make XLM float16 compliant
|
||||
pass
|
||||
|
||||
def test_xla_mode(self):
|
||||
# TODO JP: Make XLM XLA compliant
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
||||
Reference in New Issue
Block a user