Fix some bugs for two stage training of deformable detr (#25045)
* Update modeling_deformable_detr.py Fix bugs for two stage training * Update modeling_deformable_detr.py * Add test_two_stage_training to DeformableDetrModelTest --------- Co-authored-by: yupeng.jia <yupeng.jia@momenta.ai>
This commit is contained in:
@@ -1558,7 +1558,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
|||||||
def get_proposal_pos_embed(self, proposals):
|
def get_proposal_pos_embed(self, proposals):
|
||||||
"""Get the position embedding of the proposals."""
|
"""Get the position embedding of the proposals."""
|
||||||
|
|
||||||
num_pos_feats = 128
|
num_pos_feats = self.config.d_model // 2
|
||||||
temperature = 10000
|
temperature = 10000
|
||||||
scale = 2 * math.pi
|
scale = 2 * math.pi
|
||||||
|
|
||||||
@@ -1977,12 +1977,11 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
|
|||||||
outputs_coord = outputs_coord_logits.sigmoid()
|
outputs_coord = outputs_coord_logits.sigmoid()
|
||||||
outputs_classes.append(outputs_class)
|
outputs_classes.append(outputs_class)
|
||||||
outputs_coords.append(outputs_coord)
|
outputs_coords.append(outputs_coord)
|
||||||
# Keep batch_size as first dimension
|
outputs_class = torch.stack(outputs_classes)
|
||||||
outputs_class = torch.stack(outputs_classes, dim=1)
|
outputs_coord = torch.stack(outputs_coords)
|
||||||
outputs_coord = torch.stack(outputs_coords, dim=1)
|
|
||||||
|
|
||||||
logits = outputs_class[:, -1]
|
logits = outputs_class[-1]
|
||||||
pred_boxes = outputs_coord[:, -1]
|
pred_boxes = outputs_coord[-1]
|
||||||
|
|
||||||
loss, loss_dict, auxiliary_outputs = None, None, None
|
loss, loss_dict, auxiliary_outputs = None, None, None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
@@ -2008,7 +2007,7 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
|
|||||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||||
if self.config.two_stage:
|
if self.config.two_stage:
|
||||||
enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid()
|
enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid()
|
||||||
outputs["enc_outputs"] = {"pred_logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}
|
outputs_loss["enc_outputs"] = {"logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}
|
||||||
|
|
||||||
loss_dict = criterion(outputs_loss, labels)
|
loss_dict = criterion(outputs_loss, labels)
|
||||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||||
@@ -2240,7 +2239,7 @@ class DeformableDetrLoss(nn.Module):
|
|||||||
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
||||||
losses applied, see each loss' doc.
|
losses applied, see each loss' doc.
|
||||||
"""
|
"""
|
||||||
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
|
outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs" and k != "enc_outputs"}
|
||||||
|
|
||||||
# Retrieve the matching between the outputs of the last layer and the targets
|
# Retrieve the matching between the outputs of the last layer and the targets
|
||||||
indices = self.matcher(outputs_without_aux, targets)
|
indices = self.matcher(outputs_without_aux, targets)
|
||||||
@@ -2272,14 +2271,10 @@ class DeformableDetrLoss(nn.Module):
|
|||||||
enc_outputs = outputs["enc_outputs"]
|
enc_outputs = outputs["enc_outputs"]
|
||||||
bin_targets = copy.deepcopy(targets)
|
bin_targets = copy.deepcopy(targets)
|
||||||
for bt in bin_targets:
|
for bt in bin_targets:
|
||||||
bt["labels"] = torch.zeros_like(bt["labels"])
|
bt["class_labels"] = torch.zeros_like(bt["class_labels"])
|
||||||
indices = self.matcher(enc_outputs, bin_targets)
|
indices = self.matcher(enc_outputs, bin_targets)
|
||||||
for loss in self.losses:
|
for loss in self.losses:
|
||||||
kwargs = {}
|
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes)
|
||||||
if loss == "labels":
|
|
||||||
# Logging is enabled only for the last layer
|
|
||||||
kwargs["log"] = False
|
|
||||||
l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
|
|
||||||
l_dict = {k + "_enc": v for k, v in l_dict.items()}
|
l_dict = {k + "_enc": v for k, v in l_dict.items()}
|
||||||
losses.update(l_dict)
|
losses.update(l_dict)
|
||||||
|
|
||||||
|
|||||||
@@ -1463,7 +1463,7 @@ class DetaModel(DetaPreTrainedModel):
|
|||||||
def get_proposal_pos_embed(self, proposals):
|
def get_proposal_pos_embed(self, proposals):
|
||||||
"""Get the position embedding of the proposals."""
|
"""Get the position embedding of the proposals."""
|
||||||
|
|
||||||
num_pos_feats = 128
|
num_pos_feats = self.config.d_model // 2
|
||||||
temperature = 10000
|
temperature = 10000
|
||||||
scale = 2 * math.pi
|
scale = 2 * math.pi
|
||||||
|
|
||||||
|
|||||||
@@ -544,6 +544,21 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_two_stage_training(self):
|
||||||
|
model_class = DeformableDetrForObjectDetection
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.return_dict = True
|
||||||
|
config.two_stage = True
|
||||||
|
config.auxiliary_loss = True
|
||||||
|
config.with_box_refine = True
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.train()
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
loss = model(**inputs).loss
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
|
||||||
TOLERANCE = 1e-4
|
TOLERANCE = 1e-4
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user