Add test to ensure models can take int64 inputs (#17210)
* Add test to ensure models can take int64 inputs
* is_integer is an attribute, not a method
* Fix test when some inputs aren't tensors
* Add casts to blenderbot and blenderbot-small
* Add casts to the other failing models
This commit is contained in:
@@ -1287,7 +1287,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = tf.where(
|
labels = tf.where(
|
||||||
labels == self.config.pad_token_id,
|
labels == self.config.pad_token_id,
|
||||||
tf.fill(shape_list(labels), -100),
|
tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
|
||||||
labels,
|
labels,
|
||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|||||||
@@ -1265,7 +1265,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = tf.where(
|
labels = tf.where(
|
||||||
labels == self.config.pad_token_id,
|
labels == self.config.pad_token_id,
|
||||||
tf.fill(shape_list(labels), -100),
|
tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
|
||||||
labels,
|
labels,
|
||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|||||||
@@ -182,8 +182,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
|
|||||||
mask = padding_mask
|
mask = padding_mask
|
||||||
else:
|
else:
|
||||||
# assert lengths.max().item() <= slen
|
# assert lengths.max().item() <= slen
|
||||||
alen = tf.range(slen)
|
alen = tf.range(slen, dtype=lengths.dtype)
|
||||||
mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))
|
mask = alen < tf.expand_dims(lengths, axis=1)
|
||||||
|
|
||||||
# attention mask is the same as mask, or triangular inferior attention (causal)
|
# attention mask is the same as mask, or triangular inferior attention (causal)
|
||||||
if causal:
|
if causal:
|
||||||
|
|||||||
@@ -1300,7 +1300,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = tf.where(
|
labels = tf.where(
|
||||||
labels == self.config.pad_token_id,
|
labels == self.config.pad_token_id,
|
||||||
tf.fill(shape_list(labels), -100),
|
tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
|
||||||
labels,
|
labels,
|
||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|||||||
@@ -1317,7 +1317,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
|
|||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels = tf.where(
|
labels = tf.where(
|
||||||
labels == self.config.pad_token_id,
|
labels == self.config.pad_token_id,
|
||||||
tf.fill(shape_list(labels), -100),
|
tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
|
||||||
labels,
|
labels,
|
||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|||||||
@@ -1726,7 +1726,10 @@ class ProductIndexMap(IndexMap):
|
|||||||
raise ValueError("outer_index.batch_dims and inner_index.batch_dims " "must be the same.")
|
raise ValueError("outer_index.batch_dims and inner_index.batch_dims " "must be the same.")
|
||||||
|
|
||||||
super(ProductIndexMap, self).__init__(
|
super(ProductIndexMap, self).__init__(
|
||||||
indices=(inner_index.indices + outer_index.indices * inner_index.num_segments),
|
indices=(
|
||||||
|
inner_index.indices
|
||||||
|
+ outer_index.indices * tf.cast(inner_index.num_segments, inner_index.indices.dtype)
|
||||||
|
),
|
||||||
num_segments=inner_index.num_segments * outer_index.num_segments,
|
num_segments=inner_index.num_segments * outer_index.num_segments,
|
||||||
batch_dims=inner_index.batch_dims,
|
batch_dims=inner_index.batch_dims,
|
||||||
)
|
)
|
||||||
@@ -1785,7 +1788,7 @@ def flatten(index, name="segmented_flatten"):
|
|||||||
for _ in range(index.batch_dims, index.indices.shape.rank):
|
for _ in range(index.batch_dims, index.indices.shape.rank):
|
||||||
offset = tf.expand_dims(offset, -1)
|
offset = tf.expand_dims(offset, -1)
|
||||||
|
|
||||||
indices = offset + index.indices
|
indices = tf.cast(offset, index.indices.dtype) + index.indices
|
||||||
return IndexMap(indices=tf.reshape(indices, [-1]), num_segments=index.num_segments * batch_size, batch_dims=0)
|
return IndexMap(indices=tf.reshape(indices, [-1]), num_segments=index.num_segments * batch_size, batch_dims=0)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _gather_logprob(logprob, target):
|
def _gather_logprob(logprob, target):
|
||||||
lp_size = shape_list(logprob)
|
lp_size = shape_list(logprob)
|
||||||
r = tf.range(lp_size[0])
|
r = tf.range(lp_size[0], dtype=target.dtype)
|
||||||
idx = tf.stack([r, target], 1)
|
idx = tf.stack([r, target], 1)
|
||||||
return tf.gather_nd(logprob, idx)
|
return tf.gather_nd(logprob, idx)
|
||||||
|
|
||||||
|
|||||||
@@ -92,8 +92,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
|
|||||||
mask = padding_mask
|
mask = padding_mask
|
||||||
else:
|
else:
|
||||||
# assert lengths.max().item() <= slen
|
# assert lengths.max().item() <= slen
|
||||||
alen = tf.range(slen)
|
alen = tf.range(slen, dtype=lengths.dtype)
|
||||||
mask = tf.math.less(alen, tf.expand_dims(lengths, axis=1))
|
mask = alen < tf.expand_dims(lengths, axis=1)
|
||||||
|
|
||||||
# attention mask is the same as mask, or triangular inferior attention (causal)
|
# attention mask is the same as mask, or triangular inferior attention (causal)
|
||||||
if causal:
|
if causal:
|
||||||
|
|||||||
@@ -1372,6 +1372,26 @@ class TFModelTesterMixin:
|
|||||||
val_loss2 = history2.history["val_loss"][0]
|
val_loss2 = history2.history["val_loss"][0]
|
||||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||||
|
|
||||||
|
def test_int64_inputs(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
prepared_for_class = self._prepare_for_class(
|
||||||
|
inputs_dict.copy(),
|
||||||
|
model_class,
|
||||||
|
return_labels=True if "labels" in inspect.signature(model_class.call).parameters.keys() else False,
|
||||||
|
)
|
||||||
|
if not any(
|
||||||
|
[tensor.dtype.is_integer for tensor in prepared_for_class.values() if isinstance(tensor, tf.Tensor)]
|
||||||
|
):
|
||||||
|
return # No integer inputs means no need for this test
|
||||||
|
|
||||||
|
prepared_for_class = {
|
||||||
|
key: tf.cast(tensor, tf.int64) if isinstance(tensor, tf.Tensor) and tensor.dtype.is_integer else tensor
|
||||||
|
for key, tensor in prepared_for_class.items()
|
||||||
|
}
|
||||||
|
model = model_class(config)
|
||||||
|
model(**prepared_for_class) # No assertion, we're just checking this doesn't throw an error
|
||||||
|
|
||||||
def test_generate_with_headmasking(self):
|
def test_generate_with_headmasking(self):
|
||||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user