From 52cc204dd7fbd671452448028aae6262cea74dc2 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Fri, 28 Mar 2025 15:52:11 +0100 Subject: [PATCH] [blip-2] Fix dtype mismatch when keep in fp32 (#37068) * fix fp32 BLIP2 * no need to reorder that * check for `Noneness` as well before casting dtype --- .../models/blip_2/modeling_blip_2.py | 29 +++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index ab5a2a9abd..de15c0d1ed 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1238,6 +1238,9 @@ class Blip2TextEmbeddings(nn.Module): embeddings += position_embeddings if query_embeds is not None: + # `query_embeds` are kept in fp32 when we use it with Qformer + if query_embeds.dtype != embeddings.dtype: + query_embeds = query_embeds.to(embeddings.dtype) embeddings = torch.cat((query_embeds, embeddings), dim=1) else: embeddings = query_embeds @@ -1385,6 +1388,10 @@ class Blip2QFormerModel(Blip2PreTrainedModel): # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if encoder_hidden_states is not None: + # Qformer and latent query tokens are kept in fp32. We cast `encoder_hidden_states` if not fp32 already + if encoder_hidden_states.dtype != query_embeds.dtype: + encoder_hidden_states = encoder_hidden_states.to(query_embeds.dtype) + if isinstance(encoder_hidden_states, list): encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() else: @@ -1447,7 +1454,7 @@ class Blip2QFormerModel(Blip2PreTrainedModel): class Blip2Model(Blip2PreTrainedModel): config_class = Blip2Config main_input_name = "pixel_values" - _keep_in_fp32_modules = ["query_tokens"] + _keep_in_fp32_modules = ["query_tokens", "qformer"] def __init__(self, config: Blip2Config): super().__init__(config) @@ -1728,6 +1735,10 @@ class Blip2Model(Blip2PreTrainedModel): ) query_output = query_outputs[0] + # Qformer is kept in fp32, we downcast the output back if needed + if query_output.dtype != image_embeds.dtype: + query_output = query_output.to(image_embeds.dtype) + # step 3: use the language model, conditioned on the query outputs and the prompt language_model_inputs = self.language_projection(query_output) language_model_attention_mask = torch.ones( @@ -1799,7 +1810,7 @@ class Blip2Model(Blip2PreTrainedModel): ) class Blip2TextModelWithProjection(Blip2PreTrainedModel): supports_gradient_checkpointing = False - _keep_in_fp32_modules = ["query_tokens"] + _keep_in_fp32_modules = ["query_tokens", "qformer"] def __init__(self, config: Blip2Config): super().__init__(config) @@ -1898,7 +1909,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel): ) class Blip2VisionModelWithProjection(Blip2PreTrainedModel): main_input_name = "pixel_values" - _keep_in_fp32_modules = ["query_tokens"] + _keep_in_fp32_modules = ["query_tokens", "qformer"] def __init__(self, config: Blip2Config): super().__init__(config) @@ -2019,7 +2030,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): _supports_cache_class = True _supports_static_cache = True _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) - _keep_in_fp32_modules = ["query_tokens"] + _keep_in_fp32_modules = ["query_tokens", "qformer"] def __init__(self, config: Blip2Config): super().__init__(config) @@ -2192,6 +2203,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): ) query_output = query_outputs[0] + # Qformer is kept in fp32, we downcast the output back if needed + if query_output.dtype != image_embeds.dtype: + query_output = query_output.to(image_embeds.dtype) + # step 3: use the language model, conditioned on the query outputs and the prompt language_model_inputs = self.language_projection(query_output) language_model_attention_mask = torch.ones( @@ -2313,6 +2328,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): ) query_output = query_outputs.last_hidden_state + # Qformer is kept in fp32, we downcast the output back if needed + if query_output.dtype != image_embeds.dtype: + query_output = query_output.to(image_embeds.dtype) + language_model_inputs = self.language_projection(query_output) language_attention_mask = torch.ones( language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device @@ -2372,7 +2391,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): ) class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): main_input_name = "pixel_values" - _keep_in_fp32_modules = ["query_tokens"] + _keep_in_fp32_modules = ["query_tokens", "qformer"] def __init__(self, config: Blip2Config): super().__init__(config)