SuperPointModel -> SuperPointForKeypointDetection (#29757)

This commit is contained in:
amyeroberts
2024-03-20 15:41:03 +00:00
committed by GitHub
parent 1248f09252
commit 3c17c529cc
11 changed files with 63 additions and 30 deletions

View File

@@ -250,6 +250,10 @@ The following auto classes are available for the following computer vision tasks
[[autodoc]] AutoModelForVideoClassification
### AutoModelForKeypointDetection
[[autodoc]] AutoModelForKeypointDetection
### AutoModelForMaskedImageModeling
[[autodoc]] AutoModelForMaskedImageModeling

View File

@@ -113,10 +113,8 @@ The original code can be found [here](https://github.com/magicleap/SuperPointPre
- preprocess
## SuperPointModel
## SuperPointForKeypointDetection
[[autodoc]] SuperPointModel
[[autodoc]] SuperPointForKeypointDetection
- forward

View File

@@ -1487,6 +1487,7 @@ else:
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
"MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MASK_GENERATION_MAPPING",
@@ -1527,6 +1528,7 @@ else:
"AutoModelForImageSegmentation",
"AutoModelForImageToImage",
"AutoModelForInstanceSegmentation",
"AutoModelForKeypointDetection",
"AutoModelForMaskedImageModeling",
"AutoModelForMaskedLM",
"AutoModelForMaskGeneration",
@@ -3341,7 +3343,7 @@ else:
_import_structure["models.superpoint"].extend(
[
"SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST",
"SuperPointModel",
"SuperPointForKeypointDetection",
"SuperPointPreTrainedModel",
]
)
@@ -6319,6 +6321,7 @@ if TYPE_CHECKING:
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_KEYPOINT_DETECTION_MAPPING,
MODEL_FOR_MASK_GENERATION_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
@@ -6359,6 +6362,7 @@ if TYPE_CHECKING:
AutoModelForImageSegmentation,
AutoModelForImageToImage,
AutoModelForInstanceSegmentation,
AutoModelForKeypointDetection,
AutoModelForMaskedImageModeling,
AutoModelForMaskedLM,
AutoModelForMaskGeneration,
@@ -7852,7 +7856,7 @@ if TYPE_CHECKING:
)
from .models.superpoint import (
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
SuperPointModel,
SuperPointForKeypointDetection,
SuperPointPreTrainedModel,
)
from .models.swiftformer import (

View File

@@ -52,6 +52,7 @@ else:
"MODEL_FOR_IMAGE_MAPPING",
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
"MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING",
@@ -92,6 +93,7 @@ else:
"AutoModelForImageSegmentation",
"AutoModelForImageToImage",
"AutoModelForInstanceSegmentation",
"AutoModelForKeypointDetection",
"AutoModelForMaskGeneration",
"AutoModelForTextEncoding",
"AutoModelForMaskedImageModeling",
@@ -117,7 +119,6 @@ else:
"AutoModelWithLMHead",
"AutoModelForZeroShotImageClassification",
"AutoModelForZeroShotObjectDetection",
"AutoModelForKeypointDetection",
]
try:
@@ -239,6 +240,7 @@ if TYPE_CHECKING:
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_KEYPOINT_DETECTION_MAPPING,
MODEL_FOR_MASK_GENERATION_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,

View File

@@ -207,7 +207,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
("squeezebert", "SqueezeBertModel"),
("stablelm", "StableLmModel"),
("starcoder2", "Starcoder2Model"),
("superpoint", "SuperPointModel"),
("swiftformer", "SwiftFormerModel"),
("swin", "SwinModel"),
("swin2sr", "Swin2SRModel"),
@@ -1225,6 +1224,14 @@ MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
[
("superpoint", "SuperPointForKeypointDetection"),
]
)
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
[
("albert", "AlbertModel"),
@@ -1360,6 +1367,10 @@ MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BA
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
@@ -1377,6 +1388,10 @@ class AutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
class AutoModelForKeypointDetection(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING
class AutoModelForTextEncoding(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING

View File

@@ -40,7 +40,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["modeling_superpoint"] = [
"SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST",
"SuperPointModel",
"SuperPointForKeypointDetection",
"SuperPointPreTrainedModel",
]
@@ -67,7 +67,7 @@ if TYPE_CHECKING:
else:
from .modeling_superpoint import (
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
SuperPointModel,
SuperPointForKeypointDetection,
SuperPointPreTrainedModel,
)

View File

@@ -26,7 +26,7 @@ SUPERPOINT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
class SuperPointConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SuperPointModel`]. It is used to instantiate a
This is the configuration class to store the configuration of a [`SuperPointForKeypointDetection`]. It is used to instantiate a
SuperPoint model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the SuperPoint
[magic-leap-community/superpoint](https://huggingface.co/magic-leap-community/superpoint) architecture.
@@ -53,12 +53,12 @@ class SuperPointConfig(PretrainedConfig):
Example:
```python
>>> from transformers import SuperPointConfig, SuperPointModel
>>> from transformers import SuperPointConfig, SuperPointForKeypointDetection
>>> # Initializing a SuperPoint superpoint style configuration
>>> configuration = SuperPointConfig()
>>> # Initializing a model from the superpoint style configuration
>>> model = SuperPointModel(configuration)
>>> model = SuperPointForKeypointDetection(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""

View File

@@ -18,7 +18,7 @@ import requests
import torch
from PIL import Image
from transformers import SuperPointConfig, SuperPointImageProcessor, SuperPointModel
from transformers import SuperPointConfig, SuperPointForKeypointDetection, SuperPointImageProcessor
def get_superpoint_config():
@@ -106,7 +106,7 @@ def convert_superpoint_checkpoint(checkpoint_url, pytorch_dump_folder_path, save
rename_key(new_state_dict, src, dest)
# Load HuggingFace model
model = SuperPointModel(config)
model = SuperPointForKeypointDetection(config)
model.load_state_dict(new_state_dict)
model.eval()
print("Successfully loaded weights in the model")

View File

@@ -390,7 +390,7 @@ Args:
"SuperPoint model outputting keypoints and descriptors.",
SUPERPOINT_START_DOCSTRING,
)
class SuperPointModel(SuperPointPreTrainedModel):
class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
"""
SuperPoint model. It consists of a SuperPointEncoder, a SuperPointInterestPointDecoder and a
SuperPointDescriptorDecoder. SuperPoint was proposed in `SuperPoint: Self-Supervised Interest Point Detection and

View File

@@ -606,6 +606,9 @@ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = None
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = None
MODEL_FOR_MASK_GENERATION_MAPPING = None
@@ -778,6 +781,13 @@ class AutoModelForInstanceSegmentation(metaclass=DummyObject):
requires_backends(self, ["torch"])
class AutoModelForKeypointDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoModelForMaskedImageModeling(metaclass=DummyObject):
_backends = ["torch"]
@@ -8029,7 +8039,7 @@ class Starcoder2PreTrainedModel(metaclass=DummyObject):
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class SuperPointModel(metaclass=DummyObject):
class SuperPointForKeypointDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):

View File

@@ -28,7 +28,7 @@ if is_torch_available():
from transformers import (
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
SuperPointModel,
SuperPointForKeypointDetection,
)
if is_vision_available():
@@ -86,7 +86,7 @@ class SuperPointModelTester:
)
def create_and_check_model(self, config, pixel_values):
model = SuperPointModel(config=config)
model = SuperPointForKeypointDetection(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
@@ -109,7 +109,7 @@ class SuperPointModelTester:
@require_torch
class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (SuperPointModel,) if is_torch_available() else ()
all_model_classes = (SuperPointForKeypointDetection,) if is_torch_available() else ()
all_generative_model_classes = () if is_torch_available() else ()
fx_compatible = False
@@ -134,31 +134,31 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
def create_and_test_config_common_properties(self):
return
@unittest.skip(reason="SuperPointModel does not use inputs_embeds")
@unittest.skip(reason="SuperPointForKeypointDetection does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="SuperPointModel does not support input and output embeddings")
@unittest.skip(reason="SuperPointForKeypointDetection does not support input and output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="SuperPointModel does not use feedforward chunking")
@unittest.skip(reason="SuperPointForKeypointDetection does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@@ -219,7 +219,7 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
@slow
def test_model_from_pretrained(self):
for model_name in SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = SuperPointModel.from_pretrained(model_name)
model = SuperPointForKeypointDetection.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_forward_labels_should_be_none(self):
@@ -254,7 +254,7 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
@slow
def test_inference(self):
model = SuperPointModel.from_pretrained("magic-leap-community/superpoint").to(torch_device)
model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint").to(torch_device)
preprocessor = self.default_image_processor
images = prepare_imgs()
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)