Update output of SuperPointForKeypointDetection (#29809)
* Remove auto class
* Update ImagePointDescriptionOutput
* Update model outputs
* Rename output class
* Revert "Remove auto class"
  This reverts commit ed4a8f549d79cdb0cdf7aa74205a185c41471519.
* Address comments
This commit is contained in:
@@ -79,7 +79,7 @@ def simple_nms(scores: torch.Tensor, nms_radius: int) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ImagePointDescriptionOutput(ModelOutput):
|
class SuperPointKeypointDescriptionOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of
|
Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of
|
||||||
keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images,
|
keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images,
|
||||||
@@ -88,8 +88,8 @@ class ImagePointDescriptionOutput(ModelOutput):
|
|||||||
and which are padding.
|
and which are padding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
||||||
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
Loss computed during training.
|
||||||
keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
|
keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
|
||||||
Relative (x, y) coordinates of predicted keypoints in a given image.
|
Relative (x, y) coordinates of predicted keypoints in a given image.
|
||||||
scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`):
|
scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`):
|
||||||
@@ -105,7 +105,7 @@ class ImagePointDescriptionOutput(ModelOutput):
|
|||||||
(also called feature maps) of the model at the output of each stage.
|
(also called feature maps) of the model at the output of each stage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
last_hidden_state: torch.FloatTensor = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
keypoints: Optional[torch.IntTensor] = None
|
keypoints: Optional[torch.IntTensor] = None
|
||||||
scores: Optional[torch.FloatTensor] = None
|
scores: Optional[torch.FloatTensor] = None
|
||||||
descriptors: Optional[torch.FloatTensor] = None
|
descriptors: Optional[torch.FloatTensor] = None
|
||||||
@@ -414,11 +414,11 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
|
|||||||
@add_start_docstrings_to_model_forward(SUPERPOINT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(SUPERPOINT_INPUTS_DOCSTRING)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor = None,
|
pixel_values: torch.FloatTensor,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, ImagePointDescriptionOutput]:
|
) -> Union[Tuple, SuperPointKeypointDescriptionOutput]:
|
||||||
"""
|
"""
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@@ -437,20 +437,15 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
|
|||||||
>>> inputs = processor(image, return_tensors="pt")
|
>>> inputs = processor(image, return_tensors="pt")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
```"""
|
```"""
|
||||||
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
raise ValueError(
|
raise ValueError("SuperPoint does not support training for now.")
|
||||||
f"SuperPoint is not trainable, no labels should be provided.Therefore, labels should be None but were {type(labels)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
if pixel_values is None:
|
|
||||||
raise ValueError("You have to specify pixel_values")
|
|
||||||
|
|
||||||
pixel_values = self.extract_one_channel_pixel_values(pixel_values)
|
pixel_values = self.extract_one_channel_pixel_values(pixel_values)
|
||||||
|
|
||||||
batch_size = pixel_values.shape[0]
|
batch_size = pixel_values.shape[0]
|
||||||
@@ -493,12 +488,10 @@ class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
|
|||||||
|
|
||||||
hidden_states = encoder_outputs[1] if output_hidden_states else None
|
hidden_states = encoder_outputs[1] if output_hidden_states else None
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None)
|
||||||
v for v in [last_hidden_state, keypoints, scores, descriptors, mask, hidden_states] if v is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImagePointDescriptionOutput(
|
return SuperPointKeypointDescriptionOutput(
|
||||||
last_hidden_state=last_hidden_state,
|
loss=loss,
|
||||||
keypoints=keypoints,
|
keypoints=keypoints,
|
||||||
scores=scores,
|
scores=scores,
|
||||||
descriptors=descriptors,
|
descriptors=descriptors,
|
||||||
|
|||||||
@@ -85,13 +85,17 @@ class SuperPointModelTester:
|
|||||||
border_removal_distance=self.border_removal_distance,
|
border_removal_distance=self.border_removal_distance,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, pixel_values):
|
def create_and_check_keypoint_detection(self, config, pixel_values):
|
||||||
model = SuperPointForKeypointDetection(config=config)
|
model = SuperPointForKeypointDetection(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
|
self.parent.assertEqual(result.keypoints.shape[0], self.batch_size)
|
||||||
|
self.parent.assertEqual(result.keypoints.shape[-1], 2)
|
||||||
|
|
||||||
|
result = model(pixel_values, output_hidden_states=True)
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.last_hidden_state.shape,
|
result.hidden_states[-1].shape,
|
||||||
(
|
(
|
||||||
self.batch_size,
|
self.batch_size,
|
||||||
self.encoder_hidden_sizes[-1],
|
self.encoder_hidden_sizes[-1],
|
||||||
@@ -146,19 +150,19 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_feed_forward_chunking(self):
|
def test_feed_forward_chunking(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
|
||||||
def test_training(self):
|
def test_training(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
|
||||||
def test_training_gradient_checkpointing(self):
|
def test_training_gradient_checkpointing(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
|
||||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
@unittest.skip(reason="SuperPointForKeypointDetection does not support training")
|
||||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -166,9 +170,9 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_retain_grad_hidden_states_attentions(self):
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_model(self):
|
def test_keypoint_detection(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_keypoint_detection(*config_and_inputs)
|
||||||
|
|
||||||
def test_forward_signature(self):
|
def test_forward_signature(self):
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs()
|
config, _ = self.model_tester.prepare_config_and_inputs()
|
||||||
|
|||||||
Reference in New Issue
Block a user