[blip-2] Fix dtype mismatch when keep in fp32 (#37068)
* fix fp32 BLIP2 * no need to reorder that * check for `Noneness` as well before casting dtype
This commit is contained in:
committed by
GitHub
parent
aa3778afc2
commit
52cc204dd7
@@ -1238,6 +1238,9 @@ class Blip2TextEmbeddings(nn.Module):
|
||||
embeddings += position_embeddings
|
||||
|
||||
if query_embeds is not None:
|
||||
# `query_embeds` are kept in fp32 when we use it with Qformer
|
||||
if query_embeds.dtype != embeddings.dtype:
|
||||
query_embeds = query_embeds.to(embeddings.dtype)
|
||||
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
||||
else:
|
||||
embeddings = query_embeds
|
||||
@@ -1385,6 +1388,10 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_hidden_states is not None:
|
||||
# Qformer and latent query tokens are kept in fp32. We cast `encoder_hidden_states` if not fp32 already
|
||||
if encoder_hidden_states.dtype != query_embeds.dtype:
|
||||
encoder_hidden_states = encoder_hidden_states.to(query_embeds.dtype)
|
||||
|
||||
if isinstance(encoder_hidden_states, list):
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
||||
else:
|
||||
@@ -1447,7 +1454,7 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||
class Blip2Model(Blip2PreTrainedModel):
|
||||
config_class = Blip2Config
|
||||
main_input_name = "pixel_values"
|
||||
_keep_in_fp32_modules = ["query_tokens"]
|
||||
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
@@ -1728,6 +1735,10 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
)
|
||||
query_output = query_outputs[0]
|
||||
|
||||
# Qformer is kept in fp32, we downcast the output back if needed
|
||||
if query_output.dtype != image_embeds.dtype:
|
||||
query_output = query_output.to(image_embeds.dtype)
|
||||
|
||||
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||
language_model_inputs = self.language_projection(query_output)
|
||||
language_model_attention_mask = torch.ones(
|
||||
@@ -1799,7 +1810,7 @@ class Blip2Model(Blip2PreTrainedModel):
|
||||
)
|
||||
class Blip2TextModelWithProjection(Blip2PreTrainedModel):
|
||||
supports_gradient_checkpointing = False
|
||||
_keep_in_fp32_modules = ["query_tokens"]
|
||||
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
@@ -1898,7 +1909,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel):
|
||||
)
|
||||
class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
_keep_in_fp32_modules = ["query_tokens"]
|
||||
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
@@ -2019,7 +2030,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
|
||||
_keep_in_fp32_modules = ["query_tokens"]
|
||||
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
@@ -2192,6 +2203,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
)
|
||||
query_output = query_outputs[0]
|
||||
|
||||
# Qformer is kept in fp32, we downcast the output back if needed
|
||||
if query_output.dtype != image_embeds.dtype:
|
||||
query_output = query_output.to(image_embeds.dtype)
|
||||
|
||||
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||
language_model_inputs = self.language_projection(query_output)
|
||||
language_model_attention_mask = torch.ones(
|
||||
@@ -2313,6 +2328,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
)
|
||||
query_output = query_outputs.last_hidden_state
|
||||
|
||||
# Qformer is kept in fp32, we downcast the output back if needed
|
||||
if query_output.dtype != image_embeds.dtype:
|
||||
query_output = query_output.to(image_embeds.dtype)
|
||||
|
||||
language_model_inputs = self.language_projection(query_output)
|
||||
language_attention_mask = torch.ones(
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
@@ -2372,7 +2391,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
)
|
||||
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
_keep_in_fp32_modules = ["query_tokens"]
|
||||
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
|
||||
Reference in New Issue
Block a user