Fix some bugs for two stage training of deformable detr (#25045)
* Update modeling_deformable_detr.py

  Fix bugs for two stage training

* Update modeling_deformable_detr.py

* Add test_two_stage_training to DeformableDetrModelTest

---------

Co-authored-by: yupeng.jia <yupeng.jia@momenta.ai>
This commit is contained in:
@@ -1558,7 +1558,7 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
def get_proposal_pos_embed(self, proposals):
|
||||
"""Get the position embedding of the proposals."""
|
||||
|
||||
-num_pos_feats = 128
+num_pos_feats = self.config.d_model // 2
|
||||
temperature = 10000
|
||||
scale = 2 * math.pi
|
||||
|
||||
@@ -1977,12 +1977,11 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
|
||||
outputs_coord = outputs_coord_logits.sigmoid()
|
||||
outputs_classes.append(outputs_class)
|
||||
outputs_coords.append(outputs_coord)
|
||||
-# Keep batch_size as first dimension
-outputs_class = torch.stack(outputs_classes, dim=1)
-outputs_coord = torch.stack(outputs_coords, dim=1)
+outputs_class = torch.stack(outputs_classes)
+outputs_coord = torch.stack(outputs_coords)

-logits = outputs_class[:, -1]
-pred_boxes = outputs_coord[:, -1]
+logits = outputs_class[-1]
+pred_boxes = outputs_coord[-1]
|
||||
|
||||
loss, loss_dict, auxiliary_outputs = None, None, None
|
||||
if labels is not None:
|
||||
@@ -2008,7 +2007,7 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
|
||||
outputs_loss["auxiliary_outputs"] = auxiliary_outputs
|
||||
if self.config.two_stage:
|
||||
enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid()
|
||||
-outputs["enc_outputs"] = {"pred_logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}
+outputs_loss["enc_outputs"] = {"logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord}
|
||||
|
||||
loss_dict = criterion(outputs_loss, labels)
|
||||
# Fourth: compute total loss, as a weighted sum of the various losses
|
||||
@@ -2240,7 +2239,7 @@ class DeformableDetrLoss(nn.Module):
|
||||
List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
|
||||
losses applied, see each loss' doc.
|
||||
"""
|
||||
-outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
+outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs" and k != "enc_outputs"}
|
||||
|
||||
# Retrieve the matching between the outputs of the last layer and the targets
|
||||
indices = self.matcher(outputs_without_aux, targets)
|
||||
@@ -2272,14 +2271,10 @@ class DeformableDetrLoss(nn.Module):
|
||||
enc_outputs = outputs["enc_outputs"]
|
||||
bin_targets = copy.deepcopy(targets)
|
||||
for bt in bin_targets:
|
||||
-bt["labels"] = torch.zeros_like(bt["labels"])
+bt["class_labels"] = torch.zeros_like(bt["class_labels"])
|
||||
indices = self.matcher(enc_outputs, bin_targets)
|
||||
for loss in self.losses:
|
||||
-kwargs = {}
-if loss == "labels":
-    # Logging is enabled only for the last layer
-    kwargs["log"] = False
-l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs)
+l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes)
|
||||
l_dict = {k + "_enc": v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
|
||||
@@ -1463,7 +1463,7 @@ class DetaModel(DetaPreTrainedModel):
|
||||
def get_proposal_pos_embed(self, proposals):
|
||||
"""Get the position embedding of the proposals."""
|
||||
|
||||
-num_pos_feats = 128
+num_pos_feats = self.config.d_model // 2
|
||||
temperature = 10000
|
||||
scale = 2 * math.pi
|
||||
|
||||
|
||||
@@ -544,6 +544,21 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
+def test_two_stage_training(self):
+    model_class = DeformableDetrForObjectDetection
+    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+    config.return_dict = True
+    config.two_stage = True
+    config.auxiliary_loss = True
+    config.with_box_refine = True
+
+    model = model_class(config)
+    model.to(torch_device)
+    model.train()
+    inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+    loss = model(**inputs).loss
+    loss.backward()
|
||||
|
||||
|
||||
TOLERANCE = 1e-4
|
||||
|
||||
|
||||
Reference in New Issue
Block a user