[Grounding DINO] Add support for cross-attention in GroundingDinoMultiHeadAttention (#30364)
* Added cross attention support * Fixed dtypes * Fixed assumption * Moved to decoder
This commit is contained in:
@@ -818,7 +818,7 @@ class GroundingDinoTextEnhancerLayer(nn.Module):
|
|||||||
attention_masks = attention_masks[:, None, :, :]
|
attention_masks = attention_masks[:, None, :, :]
|
||||||
attention_masks = attention_masks.repeat(1, self.num_heads, 1, 1)
|
attention_masks = attention_masks.repeat(1, self.num_heads, 1, 1)
|
||||||
|
|
||||||
dtype = torch.float16
|
dtype = hidden_states.dtype
|
||||||
attention_masks = attention_masks.to(dtype=dtype) # fp16 compatibility
|
attention_masks = attention_masks.to(dtype=dtype) # fp16 compatibility
|
||||||
attention_masks = (1.0 - attention_masks) * torch.finfo(dtype).min
|
attention_masks = (1.0 - attention_masks) * torch.finfo(dtype).min
|
||||||
|
|
||||||
@@ -1425,12 +1425,11 @@ class GroundingDinoDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
# Cross-Attention Text
|
# Cross-Attention Text
|
||||||
queries = self.with_pos_embed(hidden_states, position_embeddings)
|
queries = self.with_pos_embed(hidden_states, position_embeddings)
|
||||||
|
|
||||||
hidden_states, text_cross_attn_weights = self.encoder_attn_text(
|
hidden_states, text_cross_attn_weights = self.encoder_attn_text(
|
||||||
queries=queries,
|
queries=queries,
|
||||||
keys=text_encoder_hidden_states,
|
keys=text_encoder_hidden_states,
|
||||||
values=text_encoder_hidden_states,
|
values=text_encoder_hidden_states,
|
||||||
# attention_mask=text_encoder_attention_mask, # TODO fix cross-attention mask here
|
attention_mask=text_encoder_attention_mask,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1893,6 +1892,16 @@ class GroundingDinoDecoder(GroundingDinoPreTrainedModel):
|
|||||||
intermediate = ()
|
intermediate = ()
|
||||||
intermediate_reference_points = ()
|
intermediate_reference_points = ()
|
||||||
|
|
||||||
|
if text_encoder_attention_mask is not None:
|
||||||
|
dtype = text_encoder_hidden_states.dtype
|
||||||
|
|
||||||
|
text_encoder_attention_mask = text_encoder_attention_mask[:, None, None, :]
|
||||||
|
text_encoder_attention_mask = text_encoder_attention_mask.repeat(
|
||||||
|
1, self.config.decoder_attention_heads, self.config.num_queries, 1
|
||||||
|
)
|
||||||
|
text_encoder_attention_mask = text_encoder_attention_mask.to(dtype=dtype)
|
||||||
|
text_encoder_attention_mask = text_encoder_attention_mask * torch.finfo(dtype).min
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
num_coordinates = reference_points.shape[-1]
|
num_coordinates = reference_points.shape[-1]
|
||||||
if num_coordinates == 4:
|
if num_coordinates == 4:
|
||||||
|
|||||||
@@ -687,3 +687,29 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-3))
|
self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-3))
|
||||||
self.assertTrue(torch.allclose(results_cpu["boxes"], result_gpu["boxes"].cpu(), atol=1e-3))
|
self.assertTrue(torch.allclose(results_cpu["boxes"], result_gpu["boxes"].cpu(), atol=1e-3))
|
||||||
|
|
||||||
|
def test_cross_attention_mask(self):
|
||||||
|
model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(torch_device)
|
||||||
|
|
||||||
|
processor = self.default_processor
|
||||||
|
image = prepare_img()
|
||||||
|
text1 = "a cat."
|
||||||
|
text2 = "a remote control."
|
||||||
|
text_batched = [text1, text2]
|
||||||
|
|
||||||
|
encoding1 = processor(images=image, text=text1, return_tensors="pt").to(torch_device)
|
||||||
|
encoding2 = processor(images=image, text=text2, return_tensors="pt").to(torch_device)
|
||||||
|
# If we batch the text and cross attention masking is working the batched result should be equal to
|
||||||
|
# The singe text result
|
||||||
|
encoding_batched = processor(
|
||||||
|
images=[image] * len(text_batched), text=text_batched, padding="longest", return_tensors="pt"
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs1 = model(**encoding1)
|
||||||
|
outputs2 = model(**encoding2)
|
||||||
|
outputs_batched = model(**encoding_batched)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(outputs1.logits, outputs_batched.logits[:1], atol=1e-3))
|
||||||
|
# For some reason 12 elements are > 1e-3, but the rest are fine
|
||||||
|
self.assertTrue(torch.allclose(outputs2.logits, outputs_batched.logits[1:], atol=1.8e-3))
|
||||||
|
|||||||
Reference in New Issue
Block a user