From dd16acb8a3e93b643aa374c9fb80749f5235c1a6 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 14 Feb 2025 17:43:32 +0100 Subject: [PATCH] set `test_torchscript = False` for Blip2 testing (#35972) * just skip * fix * fix * fix --------- Co-authored-by: ydshieh --- tests/models/blip_2/test_modeling_blip_2.py | 227 +------------------- 1 file changed, 3 insertions(+), 224 deletions(-) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index a405a1f97f..17d14c8486 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -15,7 +15,6 @@ """Testing suite for the PyTorch BLIP-2 model.""" import inspect -import os import tempfile import unittest @@ -36,7 +35,7 @@ from transformers.testing_utils import ( slow, torch_device, ) -from transformers.utils import is_torch_available, is_torch_sdpa_available, is_vision_available +from transformers.utils import is_torch_available, is_vision_available from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -477,7 +476,7 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT test_pruning = False test_resize_embeddings = False test_attention_outputs = False - test_torchscript = True + test_torchscript = False _is_composite = True def setUp(self): @@ -494,116 +493,6 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs) - def _create_and_check_torchscript(self, config, inputs_dict): - # overwrite because BLIP requires ipnut ids and pixel values as input - if not self.test_torchscript: - self.skipTest(reason="test_torchscript is set to `False`") - - configs_no_init = _config_zero_init(config) # To be sure we have no Nan - configs_no_init.torchscript = True - for model_class in self.all_model_classes: - for attn_implementation in ["eager", "sdpa"]: - if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()): - continue - - configs_no_init._attn_implementation = attn_implementation - model = model_class(config=configs_no_init) - model.to(torch_device) - model.eval() - inputs = self._prepare_for_class(inputs_dict, model_class) - - main_input_name = model_class.main_input_name - - try: - if model.config.is_encoder_decoder: - model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - main_input = inputs[main_input_name] - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - decoder_input_ids = inputs["decoder_input_ids"] - decoder_attention_mask = inputs["decoder_attention_mask"] - model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask) - traced_model = torch.jit.trace( - model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask) - ) - else: - main_input = inputs[main_input_name] - input_ids = inputs["input_ids"] - - if model.config._attn_implementation == "sdpa": - trace_input = {main_input_name: main_input, "input_ids": input_ids} - - if "attention_mask" in inputs: - trace_input["attention_mask"] = inputs["attention_mask"] - else: - self.skipTest(reason="testing SDPA without attention_mask is not supported") - - model(main_input, attention_mask=inputs["attention_mask"]) - # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1. - traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input) - else: - model(main_input, input_ids) - traced_model = torch.jit.trace(model, (main_input, input_ids)) - except RuntimeError: - self.fail("Couldn't trace module.") - - with tempfile.TemporaryDirectory() as tmp_dir_name: - pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") - - try: - torch.jit.save(traced_model, pt_file_name) - except Exception: - self.fail("Couldn't save module.") - - try: - loaded_model = torch.jit.load(pt_file_name) - except Exception: - self.fail("Couldn't load module.") - - model.to(torch_device) - model.eval() - - loaded_model.to(torch_device) - loaded_model.eval() - - model_state_dict = model.state_dict() - loaded_model_state_dict = loaded_model.state_dict() - - non_persistent_buffers = {} - for key in loaded_model_state_dict.keys(): - if key not in model_state_dict.keys(): - non_persistent_buffers[key] = loaded_model_state_dict[key] - - loaded_model_state_dict = { - key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers - } - - self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) - - model_buffers = list(model.buffers()) - for non_persistent_buffer in non_persistent_buffers.values(): - found_buffer = False - for i, model_buffer in enumerate(model_buffers): - if torch.equal(non_persistent_buffer, model_buffer): - found_buffer = True - break - - self.assertTrue(found_buffer) - model_buffers.pop(i) - - models_equal = True - for layer_name, p1 in model_state_dict.items(): - if layer_name in loaded_model_state_dict: - p2 = loaded_model_state_dict[layer_name] - if p1.data.ne(p2.data).sum() > 0: - models_equal = False - - self.assertTrue(models_equal) - - # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. - # (Even with this call, there are still memory leak by ~0.04MB) - self.clear_torch_jit_class_registry() - @unittest.skip(reason="Hidden_states is tested in individual model tests") def test_hidden_states_output(self): pass @@ -1015,7 +904,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi test_pruning = False test_resize_embeddings = True test_attention_outputs = False - test_torchscript = True + test_torchscript = False _is_composite = True # TODO: Fix the failed tests @@ -1049,116 +938,6 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs) - def _create_and_check_torchscript(self, config, inputs_dict): - # overwrite because BLIP requires ipnut ids and pixel values as input - if not self.test_torchscript: - self.skipTest(reason="test_torchscript is set to `False`") - - configs_no_init = _config_zero_init(config) # To be sure we have no Nan - configs_no_init.torchscript = True - for model_class in self.all_model_classes: - for attn_implementation in ["eager", "sdpa"]: - if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()): - continue - - configs_no_init._attn_implementation = attn_implementation - model = model_class(config=configs_no_init) - model.to(torch_device) - model.eval() - inputs = self._prepare_for_class(inputs_dict, model_class) - - main_input_name = model_class.main_input_name - - try: - if model.config.is_encoder_decoder: - model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward - main_input = inputs[main_input_name] - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - decoder_input_ids = inputs["decoder_input_ids"] - decoder_attention_mask = inputs["decoder_attention_mask"] - model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask) - traced_model = torch.jit.trace( - model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask) - ) - else: - main_input = inputs[main_input_name] - input_ids = inputs["input_ids"] - - if model.config._attn_implementation == "sdpa": - trace_input = {main_input_name: main_input, "input_ids": input_ids} - - if "attention_mask" in inputs: - trace_input["attention_mask"] = inputs["attention_mask"] - else: - self.skipTest(reason="testing SDPA without attention_mask is not supported") - - model(main_input, attention_mask=inputs["attention_mask"]) - # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1. - traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input) - else: - model(main_input, input_ids) - traced_model = torch.jit.trace(model, (main_input, input_ids)) - except RuntimeError: - self.fail("Couldn't trace module.") - - with tempfile.TemporaryDirectory() as tmp_dir_name: - pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") - - try: - torch.jit.save(traced_model, pt_file_name) - except Exception: - self.fail("Couldn't save module.") - - try: - loaded_model = torch.jit.load(pt_file_name) - except Exception: - self.fail("Couldn't load module.") - - model.to(torch_device) - model.eval() - - loaded_model.to(torch_device) - loaded_model.eval() - - model_state_dict = model.state_dict() - loaded_model_state_dict = loaded_model.state_dict() - - non_persistent_buffers = {} - for key in loaded_model_state_dict.keys(): - if key not in model_state_dict.keys(): - non_persistent_buffers[key] = loaded_model_state_dict[key] - - loaded_model_state_dict = { - key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers - } - - self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) - - model_buffers = list(model.buffers()) - for non_persistent_buffer in non_persistent_buffers.values(): - found_buffer = False - for i, model_buffer in enumerate(model_buffers): - if torch.equal(non_persistent_buffer, model_buffer): - found_buffer = True - break - - self.assertTrue(found_buffer) - model_buffers.pop(i) - - models_equal = True - for layer_name, p1 in model_state_dict.items(): - if layer_name in loaded_model_state_dict: - p2 = loaded_model_state_dict[layer_name] - if p1.data.ne(p2.data).sum() > 0: - models_equal = False - - self.assertTrue(models_equal) - - # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. - # (Even with this call, there are still memory leak by ~0.04MB) - self.clear_torch_jit_class_registry() - @unittest.skip(reason="Hidden_states is tested in individual model tests") def test_hidden_states_output(self): pass