From 899d8351f9926aa725b25a4b5625f07d7defc3c0 Mon Sep 17 00:00:00 2001 From: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com> Date: Fri, 5 Jan 2024 23:20:21 +0900 Subject: [PATCH] [DETA] Improvement and Sync from DETA especially for training (#27990) * [DETA] fix freeze/unfreeze function * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add freeze/unfreeze test case in DETA * fix type * fix typo 2 * fix : enable aux and enc loss in training pipeline * Add unsynced variables from original DETA for training * modification for passing CI test * make style * make fix * manual make fix * change deta_modeling_test of configuration 'two_stage' default to TRUE and minor change of dist checking * remove print * divide configuration in DetaModel and DetaForObjectDetection * image smaller size than 224 will give topk error * pred_boxes and logits should be equivalent to two_stage_num_proposals * add missing part in DetaConfig * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add docstring in configure and prettify TO DO part * change distribute related code to accelerate * Update src/transformers/models/deta/configuration_deta.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/deta/test_modeling_deta.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * protect importing accelerate * change variable name to specific value * wrong import --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/deta/configuration_deta.py | 6 ++ .../models/deta/image_processing_deta.py | 2 +- src/transformers/models/deta/modeling_deta.py | 62 ++++++++++++------- tests/models/deta/test_modeling_deta.py | 44 ++++++++----- 4 files changed, 74 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/deta/configuration_deta.py b/src/transformers/models/deta/configuration_deta.py index 0d8e59e960..8a89a6ddc0 100644 --- a/src/transformers/models/deta/configuration_deta.py +++ b/src/transformers/models/deta/configuration_deta.py @@ -109,6 +109,10 @@ class DetaConfig(PretrainedConfig): based on the predictions from the previous layer. focal_alpha (`float`, *optional*, defaults to 0.25): Alpha parameter in the focal loss. + assign_first_stage (`bool`, *optional*, defaults to `True`): + Whether to assign each prediction i to the highest overlapping ground truth object if the overlap is larger than a threshold 0.7. + assign_second_stage (`bool`, *optional*, defaults to `True`): + Whether to assign second assignment procedure in the second stage closely follows the first stage assignment procedure. Examples: @@ -161,6 +165,7 @@ class DetaConfig(PretrainedConfig): two_stage_num_proposals=300, with_box_refine=True, assign_first_stage=True, + assign_second_stage=True, class_cost=1, bbox_cost=5, giou_cost=2, @@ -208,6 +213,7 @@ class DetaConfig(PretrainedConfig): self.two_stage_num_proposals = two_stage_num_proposals self.with_box_refine = with_box_refine self.assign_first_stage = assign_first_stage + self.assign_second_stage = assign_second_stage if two_stage is True and with_box_refine is False: raise ValueError("If two_stage is True, with_box_refine must be True.") # Hungarian matcher diff --git a/src/transformers/models/deta/image_processing_deta.py b/src/transformers/models/deta/image_processing_deta.py index bdd7ab1118..5fdcb8df50 100644 --- a/src/transformers/models/deta/image_processing_deta.py +++ b/src/transformers/models/deta/image_processing_deta.py @@ -1052,7 +1052,7 @@ class DetaImageProcessor(BaseImageProcessor): score = all_scores[b] lbls = all_labels[b] - pre_topk = score.topk(min(10000, len(score))).indices + pre_topk = score.topk(min(10000, num_queries * num_labels)).indices box = box[pre_topk] score = score[pre_topk] lbls = lbls[pre_topk] diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 8362b49eee..330ccfe3f0 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -38,7 +38,7 @@ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import meshgrid -from ...utils import is_torchvision_available, logging, requires_backends +from ...utils import is_accelerate_available, is_torchvision_available, logging, requires_backends from ..auto import AutoBackbone from .configuration_deta import DetaConfig @@ -46,6 +46,10 @@ from .configuration_deta import DetaConfig logger = logging.get_logger(__name__) +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import reduce + if is_vision_available(): from transformers.image_transforms import center_to_corners_format @@ -105,7 +109,6 @@ class DetaDecoderOutput(ModelOutput): @dataclass -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModelOutput with DeformableDetr->Deta,Deformable DETR->DETA class DetaModelOutput(ModelOutput): """ Base class for outputs of the Deformable DETR encoder-decoder model. @@ -147,6 +150,8 @@ class DetaModelOutput(ModelOutput): foreground and background). enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): Logits of predicted bounding boxes coordinates in the first stage. + output_proposals (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`): + Logits of proposal bounding boxes coordinates in the gen_encoder_output_proposals. """ init_reference_points: torch.FloatTensor = None @@ -161,10 +166,10 @@ class DetaModelOutput(ModelOutput): encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None enc_outputs_class: Optional[torch.FloatTensor] = None enc_outputs_coord_logits: Optional[torch.FloatTensor] = None + output_proposals: Optional[torch.FloatTensor] = None @dataclass -# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrObjectDetectionOutput with DeformableDetr->Deta class DetaObjectDetectionOutput(ModelOutput): """ Output type of [`DetaForObjectDetection`]. @@ -223,6 +228,8 @@ class DetaObjectDetectionOutput(ModelOutput): foreground and background). enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): Logits of predicted bounding boxes coordinates in the first stage. + output_proposals (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.two_stage=True`): + Logits of proposal bounding boxes coordinates in the gen_encoder_output_proposals. """ loss: Optional[torch.FloatTensor] = None @@ -242,6 +249,7 @@ class DetaObjectDetectionOutput(ModelOutput): encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None enc_outputs_class: Optional = None enc_outputs_coord_logits: Optional = None + output_proposals: Optional[torch.FloatTensor] = None def _get_clones(module, N): @@ -1632,6 +1640,7 @@ class DetaModel(DetaPreTrainedModel): batch_size, _, num_channels = encoder_outputs[0].shape enc_outputs_class = None enc_outputs_coord_logits = None + output_proposals = None if self.config.two_stage: object_query_embedding, output_proposals, level_ids = self.gen_encoder_output_proposals( encoder_outputs[0], ~mask_flatten, spatial_shapes @@ -1746,6 +1755,7 @@ class DetaModel(DetaPreTrainedModel): encoder_attentions=encoder_outputs.attentions, enc_outputs_class=enc_outputs_class, enc_outputs_coord_logits=enc_outputs_coord_logits, + output_proposals=output_proposals, ) @@ -1804,12 +1814,15 @@ class DetaForObjectDetection(DetaPreTrainedModel): self.post_init() @torch.jit.unused - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrForObjectDetection._set_aux_loss def _set_aux_loss(self, outputs_class, outputs_coord): # this is a workaround to make torchscript happy, as torchscript # doesn't support dictionary with non-homogeneous values, such # as a dict having both a Tensor and a list. - return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + aux_loss = [ + {"logits": logits, "pred_boxes": pred_boxes} + for logits, pred_boxes in zip(outputs_class.transpose(0, 1)[:-1], outputs_coord.transpose(0, 1)[:-1]) + ] + return aux_loss @add_start_docstrings_to_model_forward(DETA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=DetaObjectDetectionOutput, config_class=_CONFIG_FOR_DOC) @@ -1929,21 +1942,25 @@ class DetaForObjectDetection(DetaPreTrainedModel): focal_alpha=self.config.focal_alpha, losses=losses, num_queries=self.config.num_queries, + assign_first_stage=self.config.assign_first_stage, + assign_second_stage=self.config.assign_second_stage, ) criterion.to(logits.device) # Third: compute the losses, based on outputs and labels outputs_loss = {} outputs_loss["logits"] = logits outputs_loss["pred_boxes"] = pred_boxes + outputs_loss["init_reference"] = init_reference if self.config.auxiliary_loss: - intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4] - outputs_class = self.class_embed(intermediate) - outputs_coord = self.bbox_embed(intermediate).sigmoid() auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord) outputs_loss["auxiliary_outputs"] = auxiliary_outputs if self.config.two_stage: enc_outputs_coord = outputs.enc_outputs_coord_logits.sigmoid() - outputs["enc_outputs"] = {"pred_logits": outputs.enc_outputs_class, "pred_boxes": enc_outputs_coord} + outputs_loss["enc_outputs"] = { + "logits": outputs.enc_outputs_class, + "pred_boxes": enc_outputs_coord, + "anchors": outputs.output_proposals.sigmoid(), + } loss_dict = criterion(outputs_loss, labels) # Fourth: compute total loss, as a weighted sum of the various losses @@ -1953,6 +1970,7 @@ class DetaForObjectDetection(DetaPreTrainedModel): aux_weight_dict = {} for i in range(self.config.decoder_layers - 1): aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + aux_weight_dict.update({k + "_enc": v for k, v in weight_dict.items()}) weight_dict.update(aux_weight_dict) loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) @@ -1983,6 +2001,7 @@ class DetaForObjectDetection(DetaPreTrainedModel): init_reference_points=outputs.init_reference_points, enc_outputs_class=outputs.enc_outputs_class, enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + output_proposals=outputs.output_proposals, ) return dict_outputs @@ -2192,7 +2211,7 @@ class DetaLoss(nn.Module): List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the losses applied, see each loss' doc. """ - outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"} + outputs_without_aux = {k: v for k, v in outputs.items() if k not in ("auxiliary_outputs", "enc_outputs")} # Retrieve the matching between the outputs of the last layer and the targets if self.assign_second_stage: @@ -2203,11 +2222,12 @@ class DetaLoss(nn.Module): # Compute the average number of target boxes accross all nodes, for normalization purposes num_boxes = sum(len(t["class_labels"]) for t in targets) num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) - # (Niels): comment out function below, distributed training to be added - # if is_dist_avail_and_initialized(): - # torch.distributed.all_reduce(num_boxes) - # (Niels) in original implementation, num_boxes is divided by get_world_size() - num_boxes = torch.clamp(num_boxes, min=1).item() + # Check that we have initialized the distributed state + world_size = 1 + if PartialState._shared_state != {}: + num_boxes = reduce(num_boxes) + world_size = PartialState().num_processes + num_boxes = torch.clamp(num_boxes / world_size, min=1).item() # Compute all the requested losses losses = {} @@ -2228,17 +2248,13 @@ class DetaLoss(nn.Module): enc_outputs = outputs["enc_outputs"] bin_targets = copy.deepcopy(targets) for bt in bin_targets: - bt["labels"] = torch.zeros_like(bt["labels"]) + bt["class_labels"] = torch.zeros_like(bt["class_labels"]) if self.assign_first_stage: indices = self.stg1_assigner(enc_outputs, bin_targets) else: indices = self.matcher(enc_outputs, bin_targets) for loss in self.losses: - kwargs = {} - if loss == "labels": - # Logging is enabled only for the last layer - kwargs["log"] = False - l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs) + l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes) l_dict = {k + "_enc": v for k, v in l_dict.items()} losses.update(l_dict) @@ -2662,7 +2678,7 @@ class DetaStage2Assigner(nn.Module): sampled_idxs, sampled_gt_classes, ) = self._sample_proposals( # list of sampled proposal_ids, sampled_id -> [0, num_classes)+[bg_label] - matched_idxs, matched_labels, targets[b]["labels"] + matched_idxs, matched_labels, targets[b]["class_labels"] ) pos_pr_inds = sampled_idxs[sampled_gt_classes != self.bg_label] pos_gt_inds = matched_idxs[pos_pr_inds] @@ -2727,7 +2743,7 @@ class DetaStage1Assigner(nn.Module): ) # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou > 0.7, 0 if iou < 0.3, -1 ow] matched_labels = self._subsample_labels(matched_labels) - all_pr_inds = torch.arange(len(anchors)) + all_pr_inds = torch.arange(len(anchors), device=matched_labels.device) pos_pr_inds = all_pr_inds[matched_labels == 1] pos_gt_inds = matched_idxs[pos_pr_inds] pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou) diff --git a/tests/models/deta/test_modeling_deta.py b/tests/models/deta/test_modeling_deta.py index 8581723ccb..8db3485703 100644 --- a/tests/models/deta/test_modeling_deta.py +++ b/tests/models/deta/test_modeling_deta.py @@ -57,14 +57,17 @@ class DetaModelTester: hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, num_queries=12, + two_stage_num_proposals=12, num_channels=3, - image_size=196, + image_size=224, n_targets=8, num_labels=91, num_feature_levels=4, encoder_n_points=2, decoder_n_points=6, - two_stage=False, + two_stage=True, + assign_first_stage=True, + assign_second_stage=True, ): self.parent = parent self.batch_size = batch_size @@ -78,6 +81,7 @@ class DetaModelTester: self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.num_queries = num_queries + self.two_stage_num_proposals = two_stage_num_proposals self.num_channels = num_channels self.image_size = image_size self.n_targets = n_targets @@ -86,6 +90,8 @@ class DetaModelTester: self.encoder_n_points = encoder_n_points self.decoder_n_points = decoder_n_points self.two_stage = two_stage + self.assign_first_stage = assign_first_stage + self.assign_second_stage = assign_second_stage # we also set the expected seq length for both encoder and decoder self.encoder_seq_length = ( @@ -96,7 +102,7 @@ class DetaModelTester: ) self.decoder_seq_length = self.num_queries - def prepare_config_and_inputs(self): + def prepare_config_and_inputs(self, model_class_name): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) @@ -114,10 +120,10 @@ class DetaModelTester: target["masks"] = torch.rand(self.n_targets, self.image_size, self.image_size, device=torch_device) labels.append(target) - config = self.get_config() + config = self.get_config(model_class_name) return config, pixel_values, pixel_mask, labels - def get_config(self): + def get_config(self, model_class_name): resnet_config = ResNetConfig( num_channels=3, embeddings_size=10, @@ -128,6 +134,9 @@ class DetaModelTester: out_features=["stage2", "stage3", "stage4"], out_indices=[2, 3, 4], ) + two_stage = model_class_name == "DetaForObjectDetection" + assign_first_stage = model_class_name == "DetaForObjectDetection" + assign_second_stage = model_class_name == "DetaForObjectDetection" return DetaConfig( d_model=self.hidden_size, encoder_layers=self.num_hidden_layers, @@ -139,16 +148,19 @@ class DetaModelTester: dropout=self.hidden_dropout_prob, attention_dropout=self.attention_probs_dropout_prob, num_queries=self.num_queries, + two_stage_num_proposals=self.two_stage_num_proposals, num_labels=self.num_labels, num_feature_levels=self.num_feature_levels, encoder_n_points=self.encoder_n_points, decoder_n_points=self.decoder_n_points, - two_stage=self.two_stage, + two_stage=two_stage, + assign_first_stage=assign_first_stage, + assign_second_stage=assign_second_stage, backbone_config=resnet_config, ) - def prepare_config_and_inputs_for_common(self): - config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + def prepare_config_and_inputs_for_common(self, model_class_name="DetaModel"): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs(model_class_name) inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask} return config, inputs_dict @@ -190,14 +202,14 @@ class DetaModelTester: result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) result = model(pixel_values) - self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) - self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.two_stage_num_proposals, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.two_stage_num_proposals, 4)) result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels) self.parent.assertEqual(result.loss.shape, ()) - self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) - self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.two_stage_num_proposals, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.two_stage_num_proposals, 4)) @require_torchvision @@ -267,19 +279,19 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin self.config_tester.check_config_can_be_init_without_params() def test_deta_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() + config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel") self.model_tester.create_and_check_deta_model(*config_and_inputs) def test_deta_freeze_backbone(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() + config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel") self.model_tester.create_and_check_deta_freeze_backbone(*config_and_inputs) def test_deta_unfreeze_backbone(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() + config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaModel") self.model_tester.create_and_check_deta_unfreeze_backbone(*config_and_inputs) def test_deta_object_detection_head_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() + config_and_inputs = self.model_tester.prepare_config_and_inputs(model_class_name="DetaForObjectDetection") self.model_tester.create_and_check_deta_object_detection_head_model(*config_and_inputs) @unittest.skip(reason="DETA does not use inputs_embeds")