From f04257fdbcb6ecb5a9bef75f4c2a8d2e8b5a6209 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 12 May 2022 16:09:25 +0100 Subject: [PATCH] Add test to ensure models can take int64 inputs (#17210) * Add test to ensure models can take int64 inputs * is_integer is an attribute, not a method * Fix test when some inputs aren't tensors * Add casts to blenderbot and blenderbot-small * Add casts to the other failing models --- .../blenderbot/modeling_tf_blenderbot.py | 2 +- .../modeling_tf_blenderbot_small.py | 2 +- .../models/flaubert/modeling_tf_flaubert.py | 4 ++-- .../models/mbart/modeling_tf_mbart.py | 2 +- .../models/pegasus/modeling_tf_pegasus.py | 2 +- .../models/tapas/modeling_tf_tapas.py | 7 +++++-- .../modeling_tf_transfo_xl_utilities.py | 2 +- .../models/xlm/modeling_tf_xlm.py | 4 ++-- tests/test_modeling_tf_common.py | 20 +++++++++++++++++++ 9 files changed, 34 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index 7fa910a3eb..b4bceee3e2 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -1287,7 +1287,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal if labels is not None: labels = tf.where( labels == self.config.pad_token_id, - tf.fill(shape_list(labels), -100), + tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), labels, ) use_cache = False diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index 612755882e..95078af4b9 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -1265,7 +1265,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel if labels is not None: labels = tf.where( labels == self.config.pad_token_id, - tf.fill(shape_list(labels), -100), + tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), labels, ) use_cache = False diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py index d4bd3f53fd..bc49216221 100644 --- a/src/transformers/models/flaubert/modeling_tf_flaubert.py +++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py @@ -182,8 +182,8 @@ def get_masks(slen, lengths, causal, padding_mask=None): mask = padding_mask else: # assert lengths.max().item() <= slen - alen = tf.range(slen) - mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1)) + alen = tf.range(slen, dtype=lengths.dtype) + mask = alen < tf.expand_dims(lengths, axis=1) # attention mask is the same as mask, or triangular inferior attention (causal) if causal: diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index b31ac1bd63..b7de8be6e6 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -1300,7 +1300,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo if labels is not None: labels = tf.where( labels == self.config.pad_token_id, - tf.fill(shape_list(labels), -100), + tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), labels, ) use_cache = False diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index be2539b3a9..2a1b7994b6 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -1317,7 +1317,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua if labels is not None: labels = tf.where( labels == self.config.pad_token_id, - tf.fill(shape_list(labels), -100), + tf.cast(tf.fill(shape_list(labels), -100), labels.dtype), labels, ) use_cache = False diff --git a/src/transformers/models/tapas/modeling_tf_tapas.py b/src/transformers/models/tapas/modeling_tf_tapas.py index d2da064462..29cb63c3ad 100644 --- a/src/transformers/models/tapas/modeling_tf_tapas.py +++ b/src/transformers/models/tapas/modeling_tf_tapas.py @@ -1726,7 +1726,10 @@ class ProductIndexMap(IndexMap): raise ValueError("outer_index.batch_dims and inner_index.batch_dims " "must be the same.") super(ProductIndexMap, self).__init__( - indices=(inner_index.indices + outer_index.indices * inner_index.num_segments), + indices=( + inner_index.indices + + outer_index.indices * tf.cast(inner_index.num_segments, inner_index.indices.dtype) + ), num_segments=inner_index.num_segments * outer_index.num_segments, batch_dims=inner_index.batch_dims, ) @@ -1785,7 +1788,7 @@ def flatten(index, name="segmented_flatten"): for _ in range(index.batch_dims, index.indices.shape.rank): offset = tf.expand_dims(offset, -1) - indices = offset + index.indices + indices = tf.cast(offset, index.indices.dtype) + index.indices return IndexMap(indices=tf.reshape(indices, [-1]), num_segments=index.num_segments * batch_size, batch_dims=0) diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py index af95f348ec..dcfa84d0f9 100644 --- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py +++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py @@ -111,7 +111,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer): @staticmethod def _gather_logprob(logprob, target): lp_size = shape_list(logprob) - r = tf.range(lp_size[0]) + r = tf.range(lp_size[0], dtype=target.dtype) idx = tf.stack([r, target], 1) return tf.gather_nd(logprob, idx) diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py index 24d32f798f..fa3a54b6cc 100644 --- a/src/transformers/models/xlm/modeling_tf_xlm.py +++ b/src/transformers/models/xlm/modeling_tf_xlm.py @@ -92,8 +92,8 @@ def get_masks(slen, lengths, causal, padding_mask=None): mask = padding_mask else: # assert lengths.max().item() <= slen - alen = tf.range(slen) - mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1)) + alen = tf.range(slen, dtype=lengths.dtype) + mask = alen < tf.expand_dims(lengths, axis=1) # attention mask is the same as mask, or triangular inferior attention (causal) if causal: diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 0d38713e08..6edc6b20c2 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1372,6 +1372,26 @@ class TFModelTesterMixin: val_loss2 = history2.history["val_loss"][0] self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3)) + def test_int64_inputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + prepared_for_class = self._prepare_for_class( + inputs_dict.copy(), + model_class, + return_labels=True if "labels" in inspect.signature(model_class.call).parameters.keys() else False, + ) + if not any( + [tensor.dtype.is_integer for tensor in prepared_for_class.values() if isinstance(tensor, tf.Tensor)] + ): + return # No integer inputs means no need for this test + + prepared_for_class = { + key: tf.cast(tensor, tf.int64) if isinstance(tensor, tf.Tensor) and tensor.dtype.is_integer else tensor + for key, tensor in prepared_for_class.items() + } + model = model_class(config) + model(**prepared_for_class) # No assertion, we're just checking this doesn't throw an error + def test_generate_with_headmasking(self): attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"] config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()