[blip-2] Fix dtype mismatch when keep in fp32 (#37068)

* fix fp32 BLIP2

* no need to reorder that

* check for `Noneness` as well before casting dtype
This commit is contained in:
Raushan Turganbay
2025-03-28 15:52:11 +01:00
committed by GitHub
parent aa3778afc2
commit 52cc204dd7

View File

@@ -1238,6 +1238,9 @@ class Blip2TextEmbeddings(nn.Module):
embeddings += position_embeddings embeddings += position_embeddings
if query_embeds is not None: if query_embeds is not None:
# `query_embeds` are kept in fp32 when we use it with Qformer
if query_embeds.dtype != embeddings.dtype:
query_embeds = query_embeds.to(embeddings.dtype)
embeddings = torch.cat((query_embeds, embeddings), dim=1) embeddings = torch.cat((query_embeds, embeddings), dim=1)
else: else:
embeddings = query_embeds embeddings = query_embeds
@@ -1385,6 +1388,10 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
# Qformer and latent query tokens are kept in fp32. We cast `encoder_hidden_states` if not fp32 already
if encoder_hidden_states.dtype != query_embeds.dtype:
encoder_hidden_states = encoder_hidden_states.to(query_embeds.dtype)
if isinstance(encoder_hidden_states, list): if isinstance(encoder_hidden_states, list):
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
else: else:
@@ -1447,7 +1454,7 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
class Blip2Model(Blip2PreTrainedModel): class Blip2Model(Blip2PreTrainedModel):
config_class = Blip2Config config_class = Blip2Config
main_input_name = "pixel_values" main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens"] _keep_in_fp32_modules = ["query_tokens", "qformer"]
def __init__(self, config: Blip2Config): def __init__(self, config: Blip2Config):
super().__init__(config) super().__init__(config)
@@ -1728,6 +1735,10 @@ class Blip2Model(Blip2PreTrainedModel):
) )
query_output = query_outputs[0] query_output = query_outputs[0]
# Qformer is kept in fp32, we downcast the output back if needed
if query_output.dtype != image_embeds.dtype:
query_output = query_output.to(image_embeds.dtype)
# step 3: use the language model, conditioned on the query outputs and the prompt # step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output) language_model_inputs = self.language_projection(query_output)
language_model_attention_mask = torch.ones( language_model_attention_mask = torch.ones(
@@ -1799,7 +1810,7 @@ class Blip2Model(Blip2PreTrainedModel):
) )
class Blip2TextModelWithProjection(Blip2PreTrainedModel): class Blip2TextModelWithProjection(Blip2PreTrainedModel):
supports_gradient_checkpointing = False supports_gradient_checkpointing = False
_keep_in_fp32_modules = ["query_tokens"] _keep_in_fp32_modules = ["query_tokens", "qformer"]
def __init__(self, config: Blip2Config): def __init__(self, config: Blip2Config):
super().__init__(config) super().__init__(config)
@@ -1898,7 +1909,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel):
) )
class Blip2VisionModelWithProjection(Blip2PreTrainedModel): class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
main_input_name = "pixel_values" main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens"] _keep_in_fp32_modules = ["query_tokens", "qformer"]
def __init__(self, config: Blip2Config): def __init__(self, config: Blip2Config):
super().__init__(config) super().__init__(config)
@@ -2019,7 +2030,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
_supports_cache_class = True _supports_cache_class = True
_supports_static_cache = True _supports_static_cache = True
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5) _supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
_keep_in_fp32_modules = ["query_tokens"] _keep_in_fp32_modules = ["query_tokens", "qformer"]
def __init__(self, config: Blip2Config): def __init__(self, config: Blip2Config):
super().__init__(config) super().__init__(config)
@@ -2192,6 +2203,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
) )
query_output = query_outputs[0] query_output = query_outputs[0]
# Qformer is kept in fp32, we downcast the output back if needed
if query_output.dtype != image_embeds.dtype:
query_output = query_output.to(image_embeds.dtype)
# step 3: use the language model, conditioned on the query outputs and the prompt # step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output) language_model_inputs = self.language_projection(query_output)
language_model_attention_mask = torch.ones( language_model_attention_mask = torch.ones(
@@ -2313,6 +2328,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
) )
query_output = query_outputs.last_hidden_state query_output = query_outputs.last_hidden_state
# Qformer is kept in fp32, we downcast the output back if needed
if query_output.dtype != image_embeds.dtype:
query_output = query_output.to(image_embeds.dtype)
language_model_inputs = self.language_projection(query_output) language_model_inputs = self.language_projection(query_output)
language_attention_mask = torch.ones( language_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
@@ -2372,7 +2391,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
) )
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
main_input_name = "pixel_values" main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens"] _keep_in_fp32_modules = ["query_tokens", "qformer"]
def __init__(self, config: Blip2Config): def __init__(self, config: Blip2Config):
super().__init__(config) super().__init__(config)