From fdb2351ebb55fb6fe173acaca09c4f3b4e207fdc Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Wed, 17 Feb 2021 18:02:48 +0100 Subject: [PATCH] Making TF XLM-like models XLA and AMP compliant (#10211) * Fix Flaubert and XLM * Remove useless cast * Tiny fix * Tiny fix --- .../models/flaubert/modeling_tf_flaubert.py | 76 ++++++++++--------- .../models/xlm/modeling_tf_xlm.py | 75 +++++++++--------- tests/test_modeling_tf_flaubert.py | 8 -- tests/test_modeling_tf_xlm.py | 8 -- 4 files changed, 82 insertions(+), 85 deletions(-) diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py index ba712a7b0e..f3465c39f9 100644 --- a/src/transformers/models/flaubert/modeling_tf_flaubert.py +++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py @@ -171,7 +171,7 @@ FLAUBERT_INPUTS_DOCSTRING = r""" """ -def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32): +def get_masks(slen, lengths, causal, padding_mask=None): """ Generate hidden states mask, and optionally an attention mask. """ @@ -193,11 +193,9 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32): # sanity check # assert shape_list(mask) == [bs, slen] - tf.debugging.assert_equal(shape_list(mask), [bs, slen]) - assert causal is False or shape_list(attn_mask) == [bs, slen, slen] - - mask = tf.cast(mask, dtype=dtype) - attn_mask = tf.cast(attn_mask, dtype=dtype) + if tf.executing_eagerly(): + tf.debugging.assert_equal(shape_list(mask), [bs, slen]) + assert causal is False or shape_list(attn_mask) == [bs, slen, slen] return mask, attn_mask @@ -339,8 +337,7 @@ class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer): klen = shape_list(kv)[1] # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim) - dim_per_head = tf.math.divide(self.dim, self.n_heads) - dim_per_head = tf.cast(dim_per_head, dtype=tf.int32) + dim_per_head = self.dim // self.n_heads mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen) def shape(x): @@ -372,8 +369,8 @@ class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer): cache[self.layer_id] = (k, v) - q = tf.cast(q, dtype=tf.float32) - q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32))) # (bs, n_heads, qlen, dim_per_head) + f_dim_per_head = tf.cast(dim_per_head, dtype=q.dtype) + q = tf.multiply(q, tf.math.rsqrt(f_dim_per_head)) # (bs, n_heads, qlen, dim_per_head) k = tf.cast(k, dtype=q.dtype) scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen) mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen) @@ -438,22 +435,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer): self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.return_dict = config.use_return_dict + self.max_position_embeddings = config.max_position_embeddings + self.embed_init_std = config.embed_init_std self.dropout = tf.keras.layers.Dropout(config.dropout) - self.position_embeddings = tf.keras.layers.Embedding( - config.max_position_embeddings, - self.dim, - embeddings_initializer=get_initializer(config.embed_init_std), - name="position_embeddings", - ) - - if config.n_langs > 1 and config.use_lang_emb: - self.lang_embeddings = tf.keras.layers.Embedding( - self.n_langs, - self.dim, - embeddings_initializer=get_initializer(config.embed_init_std), - name="lang_embeddings", - ) - self.embeddings = TFSharedEmbeddings( self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings" ) @@ -482,6 +466,24 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer): tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2_._{}".format(i)) ) + def build(self, input_shape): + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.dim], + initializer=get_initializer(self.embed_init_std), + ) + + if self.n_langs > 1 and self.use_lang_emb: + with tf.name_scope("lang_embeddings"): + self.lang_embeddings = self.add_weight( + name="embeddings", + shape=[self.n_langs, self.dim], + initializer=get_initializer(self.embed_init_std), + ) + + super().build(input_shape) + def get_input_embeddings(self): return self.embeddings @@ -538,17 +540,18 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer): if inputs["lengths"] is None: if inputs["input_ids"] is not None: inputs["lengths"] = tf.reduce_sum( - tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1 + tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=inputs["input_ids"].dtype), axis=1 ) else: - inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32) + inputs["lengths"] = tf.convert_to_tensor([slen] * bs) # mask = input_ids != self.pad_index # check inputs # assert shape_list(lengths)[0] == bs - tf.debugging.assert_equal( - shape_list(inputs["lengths"])[0], bs - ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched" + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(inputs["lengths"])[0], bs + ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched" # assert lengths.max().item() <= slen # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 # assert (src_enc is None) == (src_len is None) @@ -564,7 +567,9 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer): # position_ids if inputs["position_ids"] is None: inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0) - else: + inputs["position_ids"] = tf.tile(inputs["position_ids"], (bs, 1)) + + if tf.executing_eagerly(): # assert shape_list(position_ids) == [bs, slen] # (slen, bs) tf.debugging.assert_equal( shape_list(inputs["position_ids"]), [bs, slen] @@ -572,7 +577,7 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer): # position_ids = position_ids.transpose(0, 1) # langs - if inputs["langs"] is not None: + if inputs["langs"] is not None and tf.executing_eagerly(): # assert shape_list(langs) == [bs, slen] # (slen, bs) tf.debugging.assert_equal( shape_list(inputs["langs"]), [bs, slen] @@ -603,15 +608,16 @@ class TFFlaubertMainLayer(tf.keras.layers.Layer): if inputs["inputs_embeds"] is None: inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"]) - tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"]) + tensor = inputs["inputs_embeds"] + tf.gather(self.position_embeddings, inputs["position_ids"]) if inputs["langs"] is not None and self.use_lang_emb: - tensor = tensor + self.lang_embeddings(inputs["langs"]) + tensor = tensor + tf.gather(self.lang_embeddings, inputs["langs"]) if inputs["token_type_ids"] is not None: tensor = tensor + self.embeddings(inputs["token_type_ids"]) tensor = self.layer_norm_emb(tensor) tensor = self.dropout(tensor, training=inputs["training"]) + mask = tf.cast(mask, dtype=tensor.dtype) tensor = tensor * tf.expand_dims(mask, axis=-1) # hidden_states and attentions cannot be None in graph mode. @@ -804,7 +810,7 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel): lang_id = self.config.lang_id effective_batch_size = inputs.shape[0] - mask_token = tf.ones((effective_batch_size, 1), dtype=tf.int32) * mask_token_id + mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id inputs = tf.concat([inputs, mask_token], axis=1) if lang_id is not None: diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py index 1ecdac8d3b..b94310a2b7 100644 --- a/src/transformers/models/xlm/modeling_tf_xlm.py +++ b/src/transformers/models/xlm/modeling_tf_xlm.py @@ -82,7 +82,7 @@ def create_sinusoidal_embeddings(n_pos, dim, out): out[:, 1::2] = tf.constant(np.cos(position_enc[:, 1::2])) -def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32): +def get_masks(slen, lengths, causal, padding_mask=None): """ Generate hidden states mask, and optionally an attention mask. """ @@ -104,11 +104,9 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32): # sanity check # assert shape_list(mask) == [bs, slen] - tf.debugging.assert_equal(shape_list(mask), [bs, slen]) - assert causal is False or shape_list(attn_mask) == [bs, slen, slen] - - mask = tf.cast(mask, dtype=dtype) - attn_mask = tf.cast(attn_mask, dtype=dtype) + if tf.executing_eagerly(): + tf.debugging.assert_equal(shape_list(mask), [bs, slen]) + assert causal is False or shape_list(attn_mask) == [bs, slen, slen] return mask, attn_mask @@ -148,8 +146,7 @@ class TFXLMMultiHeadAttention(tf.keras.layers.Layer): klen = shape_list(kv)[1] # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim) - dim_per_head = tf.math.divide(self.dim, self.n_heads) - dim_per_head = tf.cast(dim_per_head, dtype=tf.int32) + dim_per_head = self.dim // self.n_heads mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen) def shape(x): @@ -181,8 +178,8 @@ class TFXLMMultiHeadAttention(tf.keras.layers.Layer): cache[self.layer_id] = (k, v) - q = tf.cast(q, dtype=tf.float32) - q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32))) # (bs, n_heads, qlen, dim_per_head) + f_dim_per_head = tf.cast(dim_per_head, dtype=q.dtype) + q = tf.multiply(q, tf.math.rsqrt(f_dim_per_head)) # (bs, n_heads, qlen, dim_per_head) k = tf.cast(k, dtype=q.dtype) scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen) mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen) @@ -263,30 +260,18 @@ class TFXLMMainLayer(tf.keras.layers.Layer): self.hidden_dim = self.dim * 4 # 2048 by default self.n_heads = config.n_heads # 8 by default self.n_layers = config.n_layers + self.max_position_embeddings = config.max_position_embeddings + self.embed_init_std = config.embed_init_std assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads" # embeddings self.dropout = tf.keras.layers.Dropout(config.dropout) self.attention_dropout = tf.keras.layers.Dropout(config.attention_dropout) - self.position_embeddings = tf.keras.layers.Embedding( - config.max_position_embeddings, - self.dim, - embeddings_initializer=get_initializer(config.embed_init_std), - name="position_embeddings", - ) if config.sinusoidal_embeddings: raise NotImplementedError # create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight) - if config.n_langs > 1 and config.use_lang_emb: - self.lang_embeddings = tf.keras.layers.Embedding( - self.n_langs, - self.dim, - embeddings_initializer=get_initializer(config.embed_init_std), - name="lang_embeddings", - ) - self.embeddings = TFSharedEmbeddings( self.n_words, self.dim, initializer_range=config.embed_init_std, name="embeddings" ) # padding_idx=self.pad_index) @@ -326,6 +311,24 @@ class TFXLMMainLayer(tf.keras.layers.Layer): if self.attentions[int(layer)].n_heads == config.n_heads: self.prune_heads({int(layer): list(map(int, heads))}) + def build(self, input_shape): + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.dim], + initializer=get_initializer(self.embed_init_std), + ) + + if self.n_langs > 1 and self.use_lang_emb: + with tf.name_scope("lang_embeddings"): + self.lang_embeddings = self.add_weight( + name="embeddings", + shape=[self.n_langs, self.dim], + initializer=get_initializer(self.embed_init_std), + ) + + super().build(input_shape) + def get_input_embeddings(self): return self.embeddings @@ -389,17 +392,18 @@ class TFXLMMainLayer(tf.keras.layers.Layer): if inputs["lengths"] is None: if inputs["input_ids"] is not None: inputs["lengths"] = tf.reduce_sum( - tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=tf.int32), axis=1 + tf.cast(tf.not_equal(inputs["input_ids"], self.pad_index), dtype=inputs["input_ids"].dtype), axis=1 ) else: - inputs["lengths"] = tf.convert_to_tensor([slen] * bs, tf.int32) + inputs["lengths"] = tf.convert_to_tensor([slen] * bs) # mask = input_ids != self.pad_index # check inputs # assert shape_list(lengths)[0] == bs - tf.debugging.assert_equal( - shape_list(inputs["lengths"])[0], bs - ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched" + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(inputs["lengths"])[0], bs + ), f"Expected batch size {shape_list(inputs['lengths'])[0]} and received batch size {bs} mismatched" # assert lengths.max().item() <= slen # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0 # assert (src_enc is None) == (src_len is None) @@ -415,7 +419,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer): # position_ids if inputs["position_ids"] is None: inputs["position_ids"] = tf.expand_dims(tf.range(slen), axis=0) - else: + inputs["position_ids"] = tf.tile(inputs["position_ids"], (bs, 1)) + + if tf.executing_eagerly(): # assert shape_list(position_ids) == [bs, slen] # (slen, bs) tf.debugging.assert_equal( shape_list(inputs["position_ids"]), [bs, slen] @@ -423,7 +429,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer): # position_ids = position_ids.transpose(0, 1) # langs - if inputs["langs"] is not None: + if inputs["langs"] is not None and tf.executing_eagerly(): # assert shape_list(langs) == [bs, slen] # (slen, bs) tf.debugging.assert_equal( shape_list(inputs["langs"]), [bs, slen] @@ -454,15 +460,16 @@ class TFXLMMainLayer(tf.keras.layers.Layer): if inputs["inputs_embeds"] is None: inputs["inputs_embeds"] = self.embeddings(inputs["input_ids"]) - tensor = inputs["inputs_embeds"] + self.position_embeddings(inputs["position_ids"]) + tensor = inputs["inputs_embeds"] + tf.gather(self.position_embeddings, inputs["position_ids"]) if inputs["langs"] is not None and self.use_lang_emb and self.n_langs > 1: - tensor = tensor + self.lang_embeddings(inputs["langs"]) + tensor = tensor + tf.gather(self.lang_embeddings, inputs["langs"]) if inputs["token_type_ids"] is not None: tensor = tensor + self.embeddings(inputs["token_type_ids"]) tensor = self.layer_norm_emb(tensor) tensor = self.dropout(tensor, training=inputs["training"]) + mask = tf.cast(mask, dtype=tensor.dtype) tensor = tensor * tf.expand_dims(mask, axis=-1) # transformer layers @@ -837,7 +844,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel): lang_id = self.config.lang_id effective_batch_size = inputs.shape[0] - mask_token = tf.ones((effective_batch_size, 1), dtype=tf.int32) * mask_token_id + mask_token = tf.fill((effective_batch_size, 1), 1) * mask_token_id inputs = tf.concat([inputs, mask_token], axis=1) if lang_id is not None: diff --git a/tests/test_modeling_tf_flaubert.py b/tests/test_modeling_tf_flaubert.py index 24802cfbab..cd2f053ca7 100644 --- a/tests/test_modeling_tf_flaubert.py +++ b/tests/test_modeling_tf_flaubert.py @@ -331,14 +331,6 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase): model = TFFlaubertModel.from_pretrained(model_name) self.assertIsNotNone(model) - def test_mixed_precision(self): - # TODO JP: Make Flaubert float16 compliant - pass - - def test_xla_mode(self): - # TODO JP: Make Flaubert XLA compliant - pass - @require_tf @require_sentencepiece diff --git a/tests/test_modeling_tf_xlm.py b/tests/test_modeling_tf_xlm.py index 2860c99243..03dc1f0d46 100644 --- a/tests/test_modeling_tf_xlm.py +++ b/tests/test_modeling_tf_xlm.py @@ -327,14 +327,6 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs) - def test_mixed_precision(self): - # TODO JP: Make XLM float16 compliant - pass - - def test_xla_mode(self): - # TODO JP: Make XLM XLA compliant - pass - @slow def test_model_from_pretrained(self): for model_name in TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: