From 1125513a8da80d16e26cecfcbb508efc9038b5a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Ouazan?= <83456801+remi-or@users.noreply.github.com> Date: Wed, 2 Jul 2025 14:39:39 +0200 Subject: [PATCH] Blip2 fixes (#39080) * Fixed some devices errors * Fixed other device issues and more expectations * Reverted support flags * style * More granular support * Fixed some rebase stuff * add a not None check before .to --- .../models/blip_2/modeling_blip_2.py | 30 ++++++++++++--- tests/models/blip_2/test_modeling_blip_2.py | 38 +++++++++++++------ 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index c81b990b7d..48636496f9 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -415,6 +415,7 @@ class Blip2PreTrainedModel(PreTrainedModel): _no_split_modules = [ "Blip2Attention", "Blip2QFormerMultiHeadAttention", + "Blip2EncoderLayer", "Blip2TextEmbeddings", "T5Block", "OPTDecoderLayer", @@ -1262,6 +1263,7 @@ class Blip2Model(Blip2PreTrainedModel): config_class = Blip2Config main_input_name = "pixel_values" _keep_in_fp32_modules = ["query_tokens", "qformer"] + _supports_flash_attn_2 = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) @@ -1646,6 +1648,7 @@ class Blip2Model(Blip2PreTrainedModel): class Blip2TextModelWithProjection(Blip2PreTrainedModel): supports_gradient_checkpointing = False _keep_in_fp32_modules = ["query_tokens", "qformer"] + _supports_flash_attn_2 = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) @@ -1738,6 +1741,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel): class Blip2VisionModelWithProjection(Blip2PreTrainedModel): main_input_name = "pixel_values" _keep_in_fp32_modules = ["query_tokens", "qformer"] + _supports_flash_attn_2 = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) @@ -1857,6 +1861,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _keep_in_fp32_modules = ["query_tokens", "qformer"] + _supports_flash_attn_2 = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) @@ -2086,9 +2091,13 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): else: special_image_mask = input_ids == self.config.image_token_id - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) + special_image_mask = ( + special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device) + ) + language_model_inputs = language_model_inputs.to(inputs_embeds.dtype) + inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter( + special_image_mask, language_model_inputs + ) else: logger.warning_once( "Expanding inputs for image tokens in BLIP-2 should be done in processing. " @@ -2234,9 +2243,15 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): else: special_image_mask = input_ids == self.config.image_token_id - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) + special_image_mask = ( + special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device) + ) + language_model_inputs = language_model_inputs.to(inputs_embeds.dtype) + inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter( + special_image_mask, language_model_inputs + ) + + attention_mask = attention_mask.to(language_attention_mask.device) else: logger.warning_once( "Expanding inputs for image tokens in BLIP-2 should be done in processing. " @@ -2259,6 +2274,8 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask} if not self.language_model.config.is_encoder_decoder: + if input_ids is not None: + input_ids = input_ids.to(language_model_inputs.device) inputs["input_ids"] = input_ids outputs = self.language_model.generate(**inputs, **generate_kwargs) @@ -2275,6 +2292,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): main_input_name = "pixel_values" _keep_in_fp32_modules = ["query_tokens", "qformer"] + _supports_flash_attn_2 = False # because self.qformer does not support FA2 def __init__(self, config: Blip2Config): super().__init__(config) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index af95bbb2c3..4cac2f3813 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -1786,7 +1786,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase): generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() # Test output - self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118]) + expected_ids = [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118] + self.assertEqual(predictions[0].tolist(), [50265] * 32 + expected_ids) # 50265 is the img token id self.assertEqual("a woman sitting on the beach with a dog", generated_text) # image and context @@ -1797,10 +1798,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase): generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() # Test output - self.assertEqual( - predictions[0].tolist(), - [2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118], - ) + expected_ids = [2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118] + self.assertEqual(predictions[0].tolist(), [50265] * 32 + expected_ids) # 50265 is the img token id self.assertEqual(generated_text, "Question: which city is this? Answer: it's not a city, it's a beach") @require_torch_multi_accelerator @@ -1826,8 +1825,17 @@ class Blip2ModelIntegrationTest(unittest.TestCase): generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() # Test output - self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1]) - self.assertEqual("woman playing with dog on the beach", generated_text) + expected_ids_and_text = Expectations( + { + ("cuda", None): ([0, 2335, 1556, 28, 1782, 30, 8, 2608, 1], "woman playing with dog on the beach"), + ("rocm", (9, 5)): ( + [0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1], + "a woman is playing with her dog on the beach", + ), + } + ).get_expectation() + self.assertEqual(predictions[0].tolist(), expected_ids_and_text[0]) + self.assertEqual(generated_text, expected_ids_and_text[1]) # image and context prompt = "Question: which city is this? Answer:" @@ -1837,11 +1845,17 @@ class Blip2ModelIntegrationTest(unittest.TestCase): generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip() # Test output - self.assertEqual( - predictions[0].tolist(), - [0, 3, 7, 152, 67, 839, 1], - ) - self.assertEqual(generated_text, "san diego") + expected_ids_and_text = Expectations( + { + ("cuda", None): ([0, 3, 7, 152, 67, 839, 1], "san diego"), + ("rocm", (9, 5)): ( + [0, 3, 7, 152, 2515, 11389, 3523, 1], + "san francisco", # TODO: check if this is ok + ), + } + ).get_expectation() + self.assertEqual(predictions[0].tolist(), expected_ids_and_text[0]) + self.assertEqual(generated_text, expected_ids_and_text[1]) def test_expansion_in_processing(self): processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")