From 713eab45d3dff1199b823d10b0bc833d835e91e2 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 18 Oct 2022 15:00:39 +0200 Subject: [PATCH] :rotating_light: :rotating_light: :rotating_light: [Breaking change] Deformable DETR intermediate representations (#19678) * [Breaking change] Deformable DETR intermediate representations - Fixes naturally the `object-detection` pipeline. - Moves from `[n_decoders, batch_size, ...]` to `[batch_size, n_decoders, ...]` instead. * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../modeling_deformable_detr.py | 36 ++++++++++--------- .../test_pipelines_object_detection.py | 6 ---- 2 files changed, 19 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 79f0ce8b6e..db28187a09 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -135,9 +135,9 @@ class DeformableDetrDecoderOutput(ModelOutput): Args: last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. - intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): Stacked intermediate hidden states (output of each layer of the decoder). - intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`): Stacked intermediate reference points (reference points of each layer of the decoder). hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of @@ -171,9 +171,9 @@ class DeformableDetrModelOutput(ModelOutput): Initial reference points sent through the Transformer decoder. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): Sequence of hidden-states at the output of the last layer of the decoder of the model. - intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`): + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): Stacked intermediate hidden states (output of each layer of the decoder). - intermediate_reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4)`): + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): Stacked intermediate reference points (reference points of each layer of the decoder). decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of @@ -266,9 +266,9 @@ class DeformableDetrObjectDetectionOutput(ModelOutput): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_heads, 4, 4)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads. - intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`): + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): Stacked intermediate hidden states (output of each layer of the decoder). - intermediate_reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4)`): + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): Stacked intermediate reference points (reference points of each layer of the decoder). init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): Initial reference points sent through the Transformer decoder. @@ -1390,8 +1390,9 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel): if encoder_hidden_states is not None: all_cross_attentions += (layer_outputs[2],) - intermediate = torch.stack(intermediate) - intermediate_reference_points = torch.stack(intermediate_reference_points) + # Keep batch_size as first dimension + intermediate = torch.stack(intermediate, dim=1) + intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1) # add hidden states from the last decoder layer if output_hidden_states: @@ -1913,14 +1914,14 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel): outputs_classes = [] outputs_coords = [] - for level in range(hidden_states.shape[0]): + for level in range(hidden_states.shape[1]): if level == 0: reference = init_reference else: - reference = inter_references[level - 1] + reference = inter_references[:, level - 1] reference = inverse_sigmoid(reference) - outputs_class = self.class_embed[level](hidden_states[level]) - delta_bbox = self.bbox_embed[level](hidden_states[level]) + outputs_class = self.class_embed[level](hidden_states[:, level]) + delta_bbox = self.bbox_embed[level](hidden_states[:, level]) if reference.shape[-1] == 4: outputs_coord_logits = delta_bbox + reference elif reference.shape[-1] == 2: @@ -1931,11 +1932,12 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel): outputs_coord = outputs_coord_logits.sigmoid() outputs_classes.append(outputs_class) outputs_coords.append(outputs_coord) - outputs_class = torch.stack(outputs_classes) - outputs_coord = torch.stack(outputs_coords) + # Keep batch_size as first dimension + outputs_class = torch.stack(outputs_classes, dim=1) + outputs_coord = torch.stack(outputs_coords, dim=1) - logits = outputs_class[-1] - pred_boxes = outputs_coord[-1] + logits = outputs_class[:, -1] + pred_boxes = outputs_coord[:, -1] loss, loss_dict, auxiliary_outputs = None, None, None if labels is not None: @@ -2000,8 +2002,8 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel): encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, intermediate_hidden_states=outputs.intermediate_hidden_states, - init_reference_points=outputs.init_reference_points, intermediate_reference_points=outputs.intermediate_reference_points, + init_reference_points=outputs.init_reference_points, enc_outputs_class=outputs.enc_outputs_class, enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, ) diff --git a/tests/pipelines/test_pipelines_object_detection.py b/tests/pipelines/test_pipelines_object_detection.py index ebefcaab61..196f4c82ac 100644 --- a/tests/pipelines/test_pipelines_object_detection.py +++ b/tests/pipelines/test_pipelines_object_detection.py @@ -44,12 +44,6 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING def get_test_pipeline(self, model, tokenizer, feature_extractor): - if model.__class__.__name__ == "DeformableDetrForObjectDetection": - self.skipTest( - """Deformable DETR requires a custom CUDA kernel. - """ - ) - object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor) return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]