def test_batching_equivalence(self):
    """Check that batched and single-row forward passes give the same outputs.

    Each class in ``self.all_model_classes`` is run once on the full prepared
    batch and once on the leading slice of every tensor input. The first row of
    every batched output tensor must be free of ``nan``/``inf`` and must match
    the single-row output with a maximum absolute difference of at most 1e-3.
    The first entry of ``hidden_states`` is excluded from the comparison.
    """

    def recursive_check(batched_object, single_row_object, model_name, key):
        """Walk nested outputs in lockstep and compare the tensor leaves."""
        if isinstance(batched_object, (list, tuple)):
            for batched_object_value, single_row_object_value in zip(batched_object, single_row_object):
                recursive_check(batched_object_value, single_row_object_value, model_name, key)
            return
        if isinstance(batched_object, dict):
            for batched_object_value, single_row_object_value in zip(
                batched_object.values(), single_row_object.values()
            ):
                recursive_check(batched_object_value, single_row_object_value, model_name, key)
            return
        # Same guards as the common `ModelTesterMixin` batching check: `None`,
        # non-tensor values and 0-dim tensors (e.g. a returned loss) have no
        # batch dimension to slice, so they are skipped rather than compared.
        if batched_object is None or not isinstance(batched_object, torch.Tensor):
            return
        if batched_object.dim() == 0:
            return

        batched_row = batched_object[:1]
        self.assertFalse(
            torch.isnan(batched_row).any(), f"Batched output has `nan` in {model_name} for key={key}"
        )
        self.assertFalse(
            torch.isinf(batched_row).any(), f"Batched output has `inf` in {model_name} for key={key}"
        )
        self.assertFalse(
            torch.isnan(single_row_object).any(), f"Single row output has `nan` in {model_name} for key={key}"
        )
        self.assertFalse(
            torch.isinf(single_row_object).any(), f"Single row output has `inf` in {model_name} for key={key}"
        )
        # Compute the difference once; it feeds both the check and the message.
        max_difference = torch.max(torch.abs(batched_row - single_row_object))
        self.assertTrue(
            max_difference <= 1e-03,
            msg=(
                f"Batched and Single row outputs are not equal in {model_name} for key={key}. "
                f"Difference={max_difference}."
            ),
        )

    config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
    batch_size = self.model_tester.batch_size

    for model_class in self.all_model_classes:
        config.output_hidden_states = True

        model_name = model_class.__name__
        batched_input_prepared = self._prepare_for_class(batched_input, model_class)
        model = model_class(config).to(torch_device).eval()

        # Keep only tensor inputs whose leading dimension is a multiple of the
        # batch size, and take the leading 1/batch_size slice of each.
        single_row_input = {
            name: value[: value.shape[0] // batch_size]
            for name, value in batched_input_prepared.items()
            if isinstance(value, torch.Tensor) and value.shape[0] % batch_size == 0
        }

        with torch.no_grad():
            model_batched_output = model(**batched_input_prepared)
            model_row_output = model(**single_row_input)

        for key in model_batched_output:
            batched_value = model_batched_output[key]
            row_value = model_row_output[key]
            # The first hidden state in SegGPT adds the first half of the batch
            # to the second half, so it is left out of the comparison. Local
            # variables are used so the outputs are not mutated mid-iteration.
            if key == "hidden_states":
                batched_value = batched_value[1:]
                row_value = row_value[1:]
            recursive_check(batched_value, row_value, model_name, key)
is None or not isinstance(batched_object, torch.Tensor): return elif batched_object.dim() == 0: return