[GroundingDino] Fix grounding dino loss 🚨 (#31828)
* Starting to fix GroundingDinoLoss and GroundingDinoHungarianMatcher * More updates * More updates * fixed: GroundingDinoLoss * fixed: failing tests * Update src/transformers/models/grounding_dino/modeling_grounding_dino.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/grounding_dino/test_modeling_grounding_dino.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/grounding_dino/modeling_grounding_dino.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/grounding_dino/modeling_grounding_dino.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/grounding_dino/modeling_grounding_dino.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/grounding_dino/modeling_grounding_dino.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/grounding_dino/modeling_grounding_dino.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/grounding_dino/modeling_grounding_dino.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/grounding_dino/modeling_grounding_dino.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/grounding_dino/modeling_grounding_dino.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Addressed comments * Update src/transformers/models/grounding_dino/modeling_grounding_dino.py Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> * add: cardinality loss and make box loss as copy from * change: default for reduction loss is sum * fix: vectorized generate fake box * fix copies * Addressed 
comments * addressed comments * addressed one-hot * Update tests/models/grounding_dino/test_modeling_grounding_dino.py Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> * Addressed comments * fixed test * Update src/transformers/models/grounding_dino/modeling_grounding_dino.py * Update tests/models/grounding_dino/test_modeling_grounding_dino.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Starting to fix GroundingDinoLoss and GroundingDinoHungarianMatcher * More updates * More updates * fixed: GroundingDinoLoss * add: cardinality loss and make box loss as copy from * fix copies * Revert "Update tests/models/grounding_dino/test_modeling_grounding_dino.py" This reverts commit aa74c4c57c430e54cc74c414d6269edb65c73e83. * [run-slow] groundigdino * remove nestedtensor * [run-slow] groundig_dino * [run-slow] grounding_dino * [run-slow] grounding_dino * [run-slow] grounding_dino * check * check * add: enconder intermediate outputs to ImageLoss forward * add: GroundingDinoForObjectDetectionLoss in the loss directory * make style * fix the loss function * remove class_reduction since it sum is default * remove class_reduction * Update src/transformers/loss/loss_grounding_dino.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * simple fix * Update src/transformers/loss/loss_grounding_dino.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * minor fix * Update src/transformers/loss/loss_for_object_detection.py --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> Co-authored-by: sangbumchoi <danielsejong55@gmail.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -20,6 +20,8 @@ import math
|
||||
import re
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import (
|
||||
GroundingDinoConfig,
|
||||
SwinConfig,
|
||||
@@ -28,6 +30,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_timm,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
@@ -37,14 +40,14 @@ from transformers.testing_utils import (
|
||||
)
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import GroundingDinoForObjectDetection, GroundingDinoModel
|
||||
from transformers import GroundingDinoConfig, GroundingDinoForObjectDetection, GroundingDinoModel
|
||||
from transformers.pytorch_utils import id_tensor_storage
|
||||
|
||||
|
||||
@@ -54,6 +57,39 @@ if is_vision_available():
|
||||
from transformers import AutoProcessor
|
||||
|
||||
|
||||
def generate_fake_bounding_boxes(n_boxes):
    """Generate random bounding boxes in normalized (center_x, center_y, width, height) format.

    Each box is sampled uniformly at random and its center is then shifted, if
    necessary, so that the whole box lies inside the unit square, i.e.
    `center - size / 2 >= 0` and `center + size / 2 <= 1` on both axes.
    Width and height are returned exactly as sampled.

    Args:
        n_boxes (`int`): Number of boxes to generate. Must be a positive integer.

    Returns:
        `torch.Tensor` of shape `(n_boxes, 4)` whose columns are
        `(center_x, center_y, width, height)`, created by `torch.rand` with its
        default dtype and device.

    Raises:
        ValueError: If `n_boxes` is not an integer, or is not strictly positive.
    """
    # Validate the input
    if not isinstance(n_boxes, int):
        raise ValueError("n_boxes must be an integer")
    if n_boxes <= 0:
        raise ValueError("n_boxes must be a positive integer")

    # `torch.rand` samples from [0, 1), so width and height are already bounded by 1
    # and need no extra capping.
    center_x, center_y, width, height = torch.rand((n_boxes, 4)).unbind(dim=1)

    # Keep every box inside the normalized space: the center may not be closer than
    # half the box size to either border. Since size < 1, the lower bound never
    # exceeds the upper bound.
    half_width = width / 2
    half_height = height / 2
    center_x = torch.clamp(center_x, min=half_width, max=1 - half_width)
    center_y = torch.clamp(center_y, min=half_height, max=1 - half_height)

    # Combine back into bounding boxes
    return torch.stack([center_x, center_y, width, height], dim=1)
|
||||
|
||||
|
||||
class GroundingDinoModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -72,7 +108,7 @@ class GroundingDinoModelTester:
|
||||
num_channels=3,
|
||||
image_size=98,
|
||||
n_targets=8,
|
||||
num_labels=3,
|
||||
num_labels=2,
|
||||
num_feature_levels=4,
|
||||
encoder_n_points=2,
|
||||
decoder_n_points=6,
|
||||
@@ -115,7 +151,11 @@ class GroundingDinoModelTester:
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device)
|
||||
|
||||
input_ids = ids_tensor([self.batch_size, self.max_text_len], self.num_labels)
|
||||
# When using `GroundingDino` the text input template is '{label1}. {label2}. {label3}. ... {labelN}.'
|
||||
# Therefore, to avoid errors when running tests with `labels`, the `input_ids` have to follow this structure.
|
||||
# Otherwise, `build_label_maps` will throw an error when trying to split the `input_ids` into segments.
|
||||
input_ids = torch.tensor([101, 3869, 1012, 11420, 3869, 1012, 102], device=torch_device)
|
||||
input_ids = input_ids.unsqueeze(0).expand(self.batch_size, -1)
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
@@ -126,7 +166,7 @@ class GroundingDinoModelTester:
|
||||
target["class_labels"] = torch.randint(
|
||||
high=self.num_labels, size=(self.n_targets,), device=torch_device
|
||||
)
|
||||
target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
|
||||
target["boxes"] = generate_fake_bounding_boxes(self.n_targets).to(torch_device)
|
||||
target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device)
|
||||
labels.append(target)
|
||||
|
||||
@@ -317,7 +357,7 @@ class GroundingDinoModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
correct_outlen = 10
|
||||
correct_outlen = 12
|
||||
|
||||
# loss is at first position
|
||||
if "labels" in inputs_dict:
|
||||
@@ -677,6 +717,7 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
|
||||
self.assertListEqual(results["text_labels"], expected_labels)
|
||||
|
||||
@require_torch_accelerator
|
||||
@is_flaky()
|
||||
def test_inference_object_detection_head_equivalence_cpu_gpu(self):
|
||||
processor = self.default_processor
|
||||
image = prepare_img()
|
||||
@@ -716,6 +757,7 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
|
||||
torch.testing.assert_close(results_cpu["scores"], result_gpu["scores"].cpu(), rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(results_cpu["boxes"], result_gpu["boxes"].cpu(), rtol=1e-3, atol=1e-3)
|
||||
|
||||
@is_flaky()
|
||||
def test_cross_attention_mask(self):
|
||||
model = GroundingDinoForObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(torch_device)
|
||||
|
||||
@@ -740,4 +782,56 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
torch.testing.assert_close(outputs1.logits, outputs_batched.logits[:1], rtol=1e-3, atol=1e-3)
|
||||
# For some reason 12 elements are > 1e-3, but the rest are fine
|
||||
torch.testing.assert_close(outputs2.logits, outputs_batched.logits[1:], rtol=1.8e-3, atol=1.8e-3)
|
||||
self.assertTrue(torch.allclose(outputs2.logits, outputs_batched.logits[1:], atol=1.8e-3))
|
||||
|
||||
def test_grounding_dino_loss(self):
    """Check the loss of `GroundingDinoForObjectDetection` against hard-coded reference values.

    Runs the pretrained "IDEA-Research/grounding-dino-tiny" checkpoint (with
    `auxiliary_loss=True`) on the "EduardoPacheco/aquarium-sample" dataset and
    compares every entry of `outputs.loss_dict`, as well as the total
    `outputs.loss`, with the expected values below (`atol=1e-3`).
    """
    ds = load_dataset("EduardoPacheco/aquarium-sample", split="train")
    image_processor = self.default_processor.image_processor
    tokenizer = self.default_processor.tokenizer
    id2label = {0: "fish", 1: "jellyfish", 2: "penguins", 3: "sharks", 4: "puffins", 5: "stingrays", 6: "starfish"}
    # Text prompt in the "{label1}. {label2}. ... {labelN}." form
    prompt = ". ".join(id2label.values()) + "."

    # One prompt per image -- NOTE(review): assumes the dataset split holds exactly two samples; confirm.
    text_inputs = tokenizer([prompt, prompt], return_tensors="pt")
    image_inputs = image_processor(images=ds["image"], annotations=ds["annotations"], return_tensors="pt")

    # Passing auxiliary_loss=True to compare with the expected loss
    model = GroundingDinoForObjectDetection.from_pretrained(
        "IDEA-Research/grounding-dino-tiny",
        auxiliary_loss=True,
    )
    # Interested in the loss only
    model.eval()
    with torch.no_grad():
        outputs = model(**text_inputs, **image_inputs)

    # The loss differs between CPU and GPU, and these reference values may change in the future.
    # Neither the model nor the inputs are moved with `.to(torch_device)` in this test.
    # NOTE(review): the `_0`..`_4` and `_enc` suffixes presumably denote auxiliary decoder-layer
    # and encoder losses -- confirm against the loss implementation.
    expected_loss_dict = {
        "loss_ce": torch.tensor(1.1147),
        "loss_bbox": torch.tensor(0.2031),
        "loss_giou": torch.tensor(0.5819),
        "loss_ce_0": torch.tensor(1.1941),
        "loss_bbox_0": torch.tensor(0.1978),
        "loss_giou_0": torch.tensor(0.5524),
        "loss_ce_1": torch.tensor(1.1621),
        "loss_bbox_1": torch.tensor(0.1909),
        "loss_giou_1": torch.tensor(0.5892),
        "loss_ce_2": torch.tensor(1.1641),
        "loss_bbox_2": torch.tensor(0.1892),
        "loss_giou_2": torch.tensor(0.5626),
        "loss_ce_3": torch.tensor(1.1943),
        "loss_bbox_3": torch.tensor(0.1941),
        "loss_giou_3": torch.tensor(0.5607),
        "loss_ce_4": torch.tensor(1.0956),
        "loss_bbox_4": torch.tensor(0.2008),
        "loss_giou_4": torch.tensor(0.5836),
        "loss_ce_enc": torch.tensor(16226.3164),
        "loss_bbox_enc": torch.tensor(0.3063),
        "loss_giou_enc": torch.tensor(0.7380),
    }

    expected_loss = torch.tensor(32482.2305)

    # Attach a message to each assertion so a failure names the offending loss term
    # instead of the bare "False is not true".
    for key, expected_value in expected_loss_dict.items():
        actual_value = outputs.loss_dict[key]
        self.assertTrue(
            torch.allclose(actual_value, expected_value, atol=1e-3),
            msg=f"{key}: expected {expected_value}, got {actual_value}",
        )

    self.assertTrue(
        torch.allclose(outputs.loss, expected_loss, atol=1e-3),
        msg=f"loss: expected {expected_loss}, got {outputs.loss}",
    )
|
||||
|
||||
Reference in New Issue
Block a user