Byebye test_batching_equivalence's flakiness (#35729)
* fix * fix * skip * better error message --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -770,15 +770,6 @@ class ModelTesterMixin:
|
||||
different results: https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535)
|
||||
"""
|
||||
|
||||
def get_tensor_equivalence_function(batched_input):
|
||||
# models operating on continuous spaces have higher abs difference than LMs
|
||||
# instead, we can rely on cos distance for image/speech models, similar to `diffusers`
|
||||
if "input_ids" not in batched_input:
|
||||
return lambda tensor1, tensor2: (
|
||||
1.0 - F.cosine_similarity(tensor1.float().flatten(), tensor2.float().flatten(), dim=0, eps=1e-38)
|
||||
)
|
||||
return lambda tensor1, tensor2: torch.max(torch.abs(tensor1 - tensor2))
|
||||
|
||||
def recursive_check(batched_object, single_row_object, model_name, key):
|
||||
if isinstance(batched_object, (list, tuple)):
|
||||
for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
|
||||
@@ -793,6 +784,10 @@ class ModelTesterMixin:
|
||||
return
|
||||
elif batched_object.dim() == 0:
|
||||
return
|
||||
# do not compare int or bool outputs as they are mostly computed with max/argmax/topk methods which are
|
||||
# very sensitive to the inputs (e.g. tiny differences may give totally different results)
|
||||
elif not torch.is_floating_point(batched_object):
|
||||
return
|
||||
else:
|
||||
# indexing the first element does not always work
|
||||
# e.g. models that output similarity scores of size (N, M) would need to index [0, 0]
|
||||
@@ -810,19 +805,17 @@ class ModelTesterMixin:
|
||||
self.assertFalse(
|
||||
torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
|
||||
)
|
||||
self.assertTrue(
|
||||
(equivalence(batched_row, single_row_object)) <= 1e-03,
|
||||
msg=(
|
||||
f"Batched and Single row outputs are not equal in {model_name} for key={key}. "
|
||||
f"Difference={equivalence(batched_row, single_row_object)}."
|
||||
),
|
||||
)
|
||||
try:
|
||||
torch.testing.assert_close(batched_row, single_row_object, atol=1e-5, rtol=1e-5)
|
||||
except AssertionError as e:
|
||||
msg = f"Batched and Single row outputs are not equal in {model_name} for key={key}.\n\n"
|
||||
msg += str(e)
|
||||
raise AssertionError(msg)
|
||||
|
||||
set_model_tester_for_less_flaky_test(self)
|
||||
|
||||
config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
set_config_for_less_flaky_test(config)
|
||||
equivalence = get_tensor_equivalence_function(batched_input)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config.output_hidden_states = True
|
||||
|
||||
Reference in New Issue
Block a user