[blip-2] Fix dtype mismatch when keep in fp32 (#37068)
* fix fp32 BLIP2
* no need to reorder that
* check for `Noneness` as well before casting dtype
This commit is contained in:
committed by
GitHub
parent
aa3778afc2
commit
52cc204dd7
@@ -1238,6 +1238,9 @@ class Blip2TextEmbeddings(nn.Module):
|
|||||||
embeddings += position_embeddings
|
embeddings += position_embeddings
|
||||||
|
|
||||||
if query_embeds is not None:
|
if query_embeds is not None:
|
||||||
|
# `query_embeds` are kept in fp32 when we use it with Qformer
|
||||||
|
if query_embeds.dtype != embeddings.dtype:
|
||||||
|
query_embeds = query_embeds.to(embeddings.dtype)
|
||||||
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
||||||
else:
|
else:
|
||||||
embeddings = query_embeds
|
embeddings = query_embeds
|
||||||
@@ -1385,6 +1388,10 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
|||||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
|
# Qformer and latent query tokens are kept in fp32. We cast `encoder_hidden_states` if not fp32 already
|
||||||
|
if encoder_hidden_states.dtype != query_embeds.dtype:
|
||||||
|
encoder_hidden_states = encoder_hidden_states.to(query_embeds.dtype)
|
||||||
|
|
||||||
if isinstance(encoder_hidden_states, list):
|
if isinstance(encoder_hidden_states, list):
|
||||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
||||||
else:
|
else:
|
||||||
@@ -1447,7 +1454,7 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
|||||||
class Blip2Model(Blip2PreTrainedModel):
|
class Blip2Model(Blip2PreTrainedModel):
|
||||||
config_class = Blip2Config
|
config_class = Blip2Config
|
||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
_keep_in_fp32_modules = ["query_tokens"]
|
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||||
|
|
||||||
def __init__(self, config: Blip2Config):
|
def __init__(self, config: Blip2Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1728,6 +1735,10 @@ class Blip2Model(Blip2PreTrainedModel):
|
|||||||
)
|
)
|
||||||
query_output = query_outputs[0]
|
query_output = query_outputs[0]
|
||||||
|
|
||||||
|
# Qformer is kept in fp32, we downcast the output back if needed
|
||||||
|
if query_output.dtype != image_embeds.dtype:
|
||||||
|
query_output = query_output.to(image_embeds.dtype)
|
||||||
|
|
||||||
# step 3: use the language model, conditioned on the query outputs and the prompt
|
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||||
language_model_inputs = self.language_projection(query_output)
|
language_model_inputs = self.language_projection(query_output)
|
||||||
language_model_attention_mask = torch.ones(
|
language_model_attention_mask = torch.ones(
|
||||||
@@ -1799,7 +1810,7 @@ class Blip2Model(Blip2PreTrainedModel):
|
|||||||
)
|
)
|
||||||
class Blip2TextModelWithProjection(Blip2PreTrainedModel):
|
class Blip2TextModelWithProjection(Blip2PreTrainedModel):
|
||||||
supports_gradient_checkpointing = False
|
supports_gradient_checkpointing = False
|
||||||
_keep_in_fp32_modules = ["query_tokens"]
|
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||||
|
|
||||||
def __init__(self, config: Blip2Config):
|
def __init__(self, config: Blip2Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1898,7 +1909,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel):
|
|||||||
)
|
)
|
||||||
class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
|
class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
|
||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
_keep_in_fp32_modules = ["query_tokens"]
|
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||||
|
|
||||||
def __init__(self, config: Blip2Config):
|
def __init__(self, config: Blip2Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -2019,7 +2030,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
_supports_cache_class = True
|
_supports_cache_class = True
|
||||||
_supports_static_cache = True
|
_supports_static_cache = True
|
||||||
_supports_quantized_cache = False # not all LM backbones support (e.g. T5)
|
_supports_quantized_cache = False # not all LM backbones support (e.g. T5)
|
||||||
_keep_in_fp32_modules = ["query_tokens"]
|
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||||
|
|
||||||
def __init__(self, config: Blip2Config):
|
def __init__(self, config: Blip2Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -2192,6 +2203,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
)
|
)
|
||||||
query_output = query_outputs[0]
|
query_output = query_outputs[0]
|
||||||
|
|
||||||
|
# Qformer is kept in fp32, we downcast the output back if needed
|
||||||
|
if query_output.dtype != image_embeds.dtype:
|
||||||
|
query_output = query_output.to(image_embeds.dtype)
|
||||||
|
|
||||||
# step 3: use the language model, conditioned on the query outputs and the prompt
|
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||||
language_model_inputs = self.language_projection(query_output)
|
language_model_inputs = self.language_projection(query_output)
|
||||||
language_model_attention_mask = torch.ones(
|
language_model_attention_mask = torch.ones(
|
||||||
@@ -2313,6 +2328,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
)
|
)
|
||||||
query_output = query_outputs.last_hidden_state
|
query_output = query_outputs.last_hidden_state
|
||||||
|
|
||||||
|
# Qformer is kept in fp32, we downcast the output back if needed
|
||||||
|
if query_output.dtype != image_embeds.dtype:
|
||||||
|
query_output = query_output.to(image_embeds.dtype)
|
||||||
|
|
||||||
language_model_inputs = self.language_projection(query_output)
|
language_model_inputs = self.language_projection(query_output)
|
||||||
language_attention_mask = torch.ones(
|
language_attention_mask = torch.ones(
|
||||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||||
@@ -2372,7 +2391,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
)
|
)
|
||||||
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
|
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
|
||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
_keep_in_fp32_modules = ["query_tokens"]
|
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||||
|
|
||||||
def __init__(self, config: Blip2Config):
|
def __init__(self, config: Blip2Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user