🚨 🚨 🚨 [Breaking change] Deformable DETR intermediate representations (#19678)
* [Breaking change] Deformable DETR intermediate representations

  - Naturally fixes the `object-detection` pipeline.
  - Moves the intermediate outputs from `[n_decoders, batch_size, ...]` to `[batch_size, n_decoders, ...]`, so that `batch_size` is the first dimension.

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -135,9 +135,9 @@ class DeformableDetrDecoderOutput(ModelOutput):
|
|||||||
Args:
|
Args:
|
||||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||||
Sequence of hidden-states at the output of the last layer of the model.
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
|
intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
|
||||||
Stacked intermediate hidden states (output of each layer of the decoder).
|
Stacked intermediate hidden states (output of each layer of the decoder).
|
||||||
intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
|
||||||
Stacked intermediate reference points (reference points of each layer of the decoder).
|
Stacked intermediate reference points (reference points of each layer of the decoder).
|
||||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
@@ -171,9 +171,9 @@ class DeformableDetrModelOutput(ModelOutput):
|
|||||||
Initial reference points sent through the Transformer decoder.
|
Initial reference points sent through the Transformer decoder.
|
||||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
|
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
|
||||||
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
||||||
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`):
|
intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
|
||||||
Stacked intermediate hidden states (output of each layer of the decoder).
|
Stacked intermediate hidden states (output of each layer of the decoder).
|
||||||
intermediate_reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4)`):
|
intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
|
||||||
Stacked intermediate reference points (reference points of each layer of the decoder).
|
Stacked intermediate reference points (reference points of each layer of the decoder).
|
||||||
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
@@ -266,9 +266,9 @@ class DeformableDetrObjectDetectionOutput(ModelOutput):
|
|||||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_heads, 4,
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_heads, 4,
|
||||||
4)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average
|
4)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average
|
||||||
in the self-attention heads.
|
in the self-attention heads.
|
||||||
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`):
|
intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
|
||||||
Stacked intermediate hidden states (output of each layer of the decoder).
|
Stacked intermediate hidden states (output of each layer of the decoder).
|
||||||
intermediate_reference_points (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, 4)`):
|
intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
|
||||||
Stacked intermediate reference points (reference points of each layer of the decoder).
|
Stacked intermediate reference points (reference points of each layer of the decoder).
|
||||||
init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
|
init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
|
||||||
Initial reference points sent through the Transformer decoder.
|
Initial reference points sent through the Transformer decoder.
|
||||||
@@ -1390,8 +1390,9 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
|
|||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
all_cross_attentions += (layer_outputs[2],)
|
all_cross_attentions += (layer_outputs[2],)
|
||||||
|
|
||||||
intermediate = torch.stack(intermediate)
|
# Keep batch_size as first dimension
|
||||||
intermediate_reference_points = torch.stack(intermediate_reference_points)
|
intermediate = torch.stack(intermediate, dim=1)
|
||||||
|
intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
# add hidden states from the last decoder layer
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
@@ -1913,14 +1914,14 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
|
|||||||
outputs_classes = []
|
outputs_classes = []
|
||||||
outputs_coords = []
|
outputs_coords = []
|
||||||
|
|
||||||
for level in range(hidden_states.shape[0]):
|
for level in range(hidden_states.shape[1]):
|
||||||
if level == 0:
|
if level == 0:
|
||||||
reference = init_reference
|
reference = init_reference
|
||||||
else:
|
else:
|
||||||
reference = inter_references[level - 1]
|
reference = inter_references[:, level - 1]
|
||||||
reference = inverse_sigmoid(reference)
|
reference = inverse_sigmoid(reference)
|
||||||
outputs_class = self.class_embed[level](hidden_states[level])
|
outputs_class = self.class_embed[level](hidden_states[:, level])
|
||||||
delta_bbox = self.bbox_embed[level](hidden_states[level])
|
delta_bbox = self.bbox_embed[level](hidden_states[:, level])
|
||||||
if reference.shape[-1] == 4:
|
if reference.shape[-1] == 4:
|
||||||
outputs_coord_logits = delta_bbox + reference
|
outputs_coord_logits = delta_bbox + reference
|
||||||
elif reference.shape[-1] == 2:
|
elif reference.shape[-1] == 2:
|
||||||
@@ -1931,11 +1932,12 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
|
|||||||
outputs_coord = outputs_coord_logits.sigmoid()
|
outputs_coord = outputs_coord_logits.sigmoid()
|
||||||
outputs_classes.append(outputs_class)
|
outputs_classes.append(outputs_class)
|
||||||
outputs_coords.append(outputs_coord)
|
outputs_coords.append(outputs_coord)
|
||||||
outputs_class = torch.stack(outputs_classes)
|
# Keep batch_size as first dimension
|
||||||
outputs_coord = torch.stack(outputs_coords)
|
outputs_class = torch.stack(outputs_classes, dim=1)
|
||||||
|
outputs_coord = torch.stack(outputs_coords, dim=1)
|
||||||
|
|
||||||
logits = outputs_class[-1]
|
logits = outputs_class[:, -1]
|
||||||
pred_boxes = outputs_coord[-1]
|
pred_boxes = outputs_coord[:, -1]
|
||||||
|
|
||||||
loss, loss_dict, auxiliary_outputs = None, None, None
|
loss, loss_dict, auxiliary_outputs = None, None, None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
@@ -2000,8 +2002,8 @@ class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
|
|||||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||||
encoder_attentions=outputs.encoder_attentions,
|
encoder_attentions=outputs.encoder_attentions,
|
||||||
intermediate_hidden_states=outputs.intermediate_hidden_states,
|
intermediate_hidden_states=outputs.intermediate_hidden_states,
|
||||||
init_reference_points=outputs.init_reference_points,
|
|
||||||
intermediate_reference_points=outputs.intermediate_reference_points,
|
intermediate_reference_points=outputs.intermediate_reference_points,
|
||||||
|
init_reference_points=outputs.init_reference_points,
|
||||||
enc_outputs_class=outputs.enc_outputs_class,
|
enc_outputs_class=outputs.enc_outputs_class,
|
||||||
enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
|
enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -44,12 +44,6 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
|
|||||||
model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
|
model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
|
||||||
|
|
||||||
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
def get_test_pipeline(self, model, tokenizer, feature_extractor):
|
||||||
if model.__class__.__name__ == "DeformableDetrForObjectDetection":
|
|
||||||
self.skipTest(
|
|
||||||
"""Deformable DETR requires a custom CUDA kernel.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
|
object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
|
||||||
return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]
|
return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user