[Grounding DINO] Add support for cross-attention in GroundingDinoMultiHeadAttention (#30364)
* Added cross attention support * Fixed dtypes * Fixed assumption * Moved to decoder
This commit is contained in:
@@ -687,3 +687,29 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-3))
|
||||
self.assertTrue(torch.allclose(results_cpu["boxes"], result_gpu["boxes"].cpu(), atol=1e-3))
|
||||
|
||||
def test_cross_attention_mask(self):
|
||||
model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(torch_device)
|
||||
|
||||
processor = self.default_processor
|
||||
image = prepare_img()
|
||||
text1 = "a cat."
|
||||
text2 = "a remote control."
|
||||
text_batched = [text1, text2]
|
||||
|
||||
encoding1 = processor(images=image, text=text1, return_tensors="pt").to(torch_device)
|
||||
encoding2 = processor(images=image, text=text2, return_tensors="pt").to(torch_device)
|
||||
# If we batch the text and cross attention masking is working the batched result should be equal to
|
||||
# The singe text result
|
||||
encoding_batched = processor(
|
||||
images=[image] * len(text_batched), text=text_batched, padding="longest", return_tensors="pt"
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs1 = model(**encoding1)
|
||||
outputs2 = model(**encoding2)
|
||||
outputs_batched = model(**encoding_batched)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs1.logits, outputs_batched.logits[:1], atol=1e-3))
|
||||
# For some reason 12 elements are > 1e-3, but the rest are fine
|
||||
self.assertTrue(torch.allclose(outputs2.logits, outputs_batched.logits[1:], atol=1.8e-3))
|
||||
|
||||
Reference in New Issue
Block a user