SuperPointModel -> SuperPointForKeypointDetection (#29757)
This commit is contained in:
committed by
ArthurZucker
parent
1112b24394
commit
8049122576
@@ -250,6 +250,10 @@ The following auto classes are available for the following computer vision tasks
|
|||||||
|
|
||||||
[[autodoc]] AutoModelForVideoClassification
|
[[autodoc]] AutoModelForVideoClassification
|
||||||
|
|
||||||
|
### AutoModelForKeypointDetection
|
||||||
|
|
||||||
|
[[autodoc]] AutoModelForKeypointDetection
|
||||||
|
|
||||||
### AutoModelForMaskedImageModeling
|
### AutoModelForMaskedImageModeling
|
||||||
|
|
||||||
[[autodoc]] AutoModelForMaskedImageModeling
|
[[autodoc]] AutoModelForMaskedImageModeling
|
||||||
|
|||||||
@@ -113,10 +113,8 @@ The original code can be found [here](https://github.com/magicleap/SuperPointPre
|
|||||||
|
|
||||||
- preprocess
|
- preprocess
|
||||||
|
|
||||||
## SuperPointModel
|
## SuperPointForKeypointDetection
|
||||||
|
|
||||||
[[autodoc]] SuperPointModel
|
[[autodoc]] SuperPointForKeypointDetection
|
||||||
|
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1487,6 +1487,7 @@ else:
|
|||||||
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
|
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
|
||||||
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
|
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
|
||||||
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
|
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
|
||||||
|
"MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
|
||||||
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
|
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
|
||||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||||
"MODEL_FOR_MASK_GENERATION_MAPPING",
|
"MODEL_FOR_MASK_GENERATION_MAPPING",
|
||||||
@@ -1527,6 +1528,7 @@ else:
|
|||||||
"AutoModelForImageSegmentation",
|
"AutoModelForImageSegmentation",
|
||||||
"AutoModelForImageToImage",
|
"AutoModelForImageToImage",
|
||||||
"AutoModelForInstanceSegmentation",
|
"AutoModelForInstanceSegmentation",
|
||||||
|
"AutoModelForKeypointDetection",
|
||||||
"AutoModelForMaskedImageModeling",
|
"AutoModelForMaskedImageModeling",
|
||||||
"AutoModelForMaskedLM",
|
"AutoModelForMaskedLM",
|
||||||
"AutoModelForMaskGeneration",
|
"AutoModelForMaskGeneration",
|
||||||
@@ -3341,7 +3343,7 @@ else:
|
|||||||
_import_structure["models.superpoint"].extend(
|
_import_structure["models.superpoint"].extend(
|
||||||
[
|
[
|
||||||
"SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"SuperPointModel",
|
"SuperPointForKeypointDetection",
|
||||||
"SuperPointPreTrainedModel",
|
"SuperPointPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -6319,6 +6321,7 @@ if TYPE_CHECKING:
|
|||||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||||
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
|
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
|
||||||
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
|
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
|
||||||
|
MODEL_FOR_KEYPOINT_DETECTION_MAPPING,
|
||||||
MODEL_FOR_MASK_GENERATION_MAPPING,
|
MODEL_FOR_MASK_GENERATION_MAPPING,
|
||||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||||
MODEL_FOR_MASKED_LM_MAPPING,
|
MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
@@ -6359,6 +6362,7 @@ if TYPE_CHECKING:
|
|||||||
AutoModelForImageSegmentation,
|
AutoModelForImageSegmentation,
|
||||||
AutoModelForImageToImage,
|
AutoModelForImageToImage,
|
||||||
AutoModelForInstanceSegmentation,
|
AutoModelForInstanceSegmentation,
|
||||||
|
AutoModelForKeypointDetection,
|
||||||
AutoModelForMaskedImageModeling,
|
AutoModelForMaskedImageModeling,
|
||||||
AutoModelForMaskedLM,
|
AutoModelForMaskedLM,
|
||||||
AutoModelForMaskGeneration,
|
AutoModelForMaskGeneration,
|
||||||
@@ -7852,7 +7856,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.superpoint import (
|
from .models.superpoint import (
|
||||||
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
SuperPointModel,
|
SuperPointForKeypointDetection,
|
||||||
SuperPointPreTrainedModel,
|
SuperPointPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.swiftformer import (
|
from .models.swiftformer import (
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ else:
|
|||||||
"MODEL_FOR_IMAGE_MAPPING",
|
"MODEL_FOR_IMAGE_MAPPING",
|
||||||
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
|
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
|
||||||
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
|
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
|
||||||
|
"MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
|
||||||
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
|
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
|
||||||
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
|
"MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING",
|
||||||
"MODEL_FOR_MASKED_LM_MAPPING",
|
"MODEL_FOR_MASKED_LM_MAPPING",
|
||||||
@@ -92,6 +93,7 @@ else:
|
|||||||
"AutoModelForImageSegmentation",
|
"AutoModelForImageSegmentation",
|
||||||
"AutoModelForImageToImage",
|
"AutoModelForImageToImage",
|
||||||
"AutoModelForInstanceSegmentation",
|
"AutoModelForInstanceSegmentation",
|
||||||
|
"AutoModelForKeypointDetection",
|
||||||
"AutoModelForMaskGeneration",
|
"AutoModelForMaskGeneration",
|
||||||
"AutoModelForTextEncoding",
|
"AutoModelForTextEncoding",
|
||||||
"AutoModelForMaskedImageModeling",
|
"AutoModelForMaskedImageModeling",
|
||||||
@@ -117,7 +119,6 @@ else:
|
|||||||
"AutoModelWithLMHead",
|
"AutoModelWithLMHead",
|
||||||
"AutoModelForZeroShotImageClassification",
|
"AutoModelForZeroShotImageClassification",
|
||||||
"AutoModelForZeroShotObjectDetection",
|
"AutoModelForZeroShotObjectDetection",
|
||||||
"AutoModelForKeypointDetection",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -239,6 +240,7 @@ if TYPE_CHECKING:
|
|||||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
|
||||||
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
|
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
|
||||||
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
|
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
|
||||||
|
MODEL_FOR_KEYPOINT_DETECTION_MAPPING,
|
||||||
MODEL_FOR_MASK_GENERATION_MAPPING,
|
MODEL_FOR_MASK_GENERATION_MAPPING,
|
||||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||||
MODEL_FOR_MASKED_LM_MAPPING,
|
MODEL_FOR_MASKED_LM_MAPPING,
|
||||||
|
|||||||
@@ -207,7 +207,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("squeezebert", "SqueezeBertModel"),
|
("squeezebert", "SqueezeBertModel"),
|
||||||
("stablelm", "StableLmModel"),
|
("stablelm", "StableLmModel"),
|
||||||
("starcoder2", "Starcoder2Model"),
|
("starcoder2", "Starcoder2Model"),
|
||||||
("superpoint", "SuperPointModel"),
|
|
||||||
("swiftformer", "SwiftFormerModel"),
|
("swiftformer", "SwiftFormerModel"),
|
||||||
("swin", "SwinModel"),
|
("swin", "SwinModel"),
|
||||||
("swin2sr", "Swin2SRModel"),
|
("swin2sr", "Swin2SRModel"),
|
||||||
@@ -1225,6 +1224,14 @@ MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
|
||||||
|
[
|
||||||
|
("superpoint", "SuperPointForKeypointDetection"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
|
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
|
||||||
[
|
[
|
||||||
("albert", "AlbertModel"),
|
("albert", "AlbertModel"),
|
||||||
@@ -1360,6 +1367,10 @@ MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BA
|
|||||||
|
|
||||||
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
|
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
|
||||||
|
|
||||||
|
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
|
||||||
|
CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
|
MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)
|
||||||
|
|
||||||
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||||
@@ -1377,6 +1388,10 @@ class AutoModelForMaskGeneration(_BaseAutoModelClass):
|
|||||||
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
|
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForKeypointDetection(_BaseAutoModelClass):
|
||||||
|
_model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForTextEncoding(_BaseAutoModelClass):
|
class AutoModelForTextEncoding(_BaseAutoModelClass):
|
||||||
_model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
|
_model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["modeling_superpoint"] = [
|
_import_structure["modeling_superpoint"] = [
|
||||||
"SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"SuperPointModel",
|
"SuperPointForKeypointDetection",
|
||||||
"SuperPointPreTrainedModel",
|
"SuperPointPreTrainedModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -67,7 +67,7 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .modeling_superpoint import (
|
from .modeling_superpoint import (
|
||||||
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
SuperPointModel,
|
SuperPointForKeypointDetection,
|
||||||
SuperPointPreTrainedModel,
|
SuperPointPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ SUPERPOINT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
|||||||
|
|
||||||
class SuperPointConfig(PretrainedConfig):
|
class SuperPointConfig(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
This is the configuration class to store the configuration of a [`SuperPointModel`]. It is used to instantiate a
|
This is the configuration class to store the configuration of a [`SuperPointForKeypointDetection`]. It is used to instantiate a
|
||||||
SuperPoint model according to the specified arguments, defining the model architecture. Instantiating a
|
SuperPoint model according to the specified arguments, defining the model architecture. Instantiating a
|
||||||
configuration with the defaults will yield a similar configuration to that of the SuperPoint
|
configuration with the defaults will yield a similar configuration to that of the SuperPoint
|
||||||
[magic-leap-community/superpoint](https://huggingface.co/magic-leap-community/superpoint) architecture.
|
[magic-leap-community/superpoint](https://huggingface.co/magic-leap-community/superpoint) architecture.
|
||||||
@@ -53,12 +53,12 @@ class SuperPointConfig(PretrainedConfig):
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
```python
|
```python
|
||||||
>>> from transformers import SuperPointConfig, SuperPointModel
|
>>> from transformers import SuperPointConfig, SuperPointForKeypointDetection
|
||||||
|
|
||||||
>>> # Initializing a SuperPoint superpoint style configuration
|
>>> # Initializing a SuperPoint superpoint style configuration
|
||||||
>>> configuration = SuperPointConfig()
|
>>> configuration = SuperPointConfig()
|
||||||
>>> # Initializing a model from the superpoint style configuration
|
>>> # Initializing a model from the superpoint style configuration
|
||||||
>>> model = SuperPointModel(configuration)
|
>>> model = SuperPointForKeypointDetection(configuration)
|
||||||
>>> # Accessing the model configuration
|
>>> # Accessing the model configuration
|
||||||
>>> configuration = model.config
|
>>> configuration = model.config
|
||||||
```"""
|
```"""
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from transformers import SuperPointConfig, SuperPointImageProcessor, SuperPointModel
|
from transformers import SuperPointConfig, SuperPointForKeypointDetection, SuperPointImageProcessor
|
||||||
|
|
||||||
|
|
||||||
def get_superpoint_config():
|
def get_superpoint_config():
|
||||||
@@ -106,7 +106,7 @@ def convert_superpoint_checkpoint(checkpoint_url, pytorch_dump_folder_path, save
|
|||||||
rename_key(new_state_dict, src, dest)
|
rename_key(new_state_dict, src, dest)
|
||||||
|
|
||||||
# Load HuggingFace model
|
# Load HuggingFace model
|
||||||
model = SuperPointModel(config)
|
model = SuperPointForKeypointDetection(config)
|
||||||
model.load_state_dict(new_state_dict)
|
model.load_state_dict(new_state_dict)
|
||||||
model.eval()
|
model.eval()
|
||||||
print("Successfully loaded weights in the model")
|
print("Successfully loaded weights in the model")
|
||||||
|
|||||||
@@ -390,7 +390,7 @@ Args:
|
|||||||
"SuperPoint model outputting keypoints and descriptors.",
|
"SuperPoint model outputting keypoints and descriptors.",
|
||||||
SUPERPOINT_START_DOCSTRING,
|
SUPERPOINT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class SuperPointModel(SuperPointPreTrainedModel):
|
class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
|
||||||
"""
|
"""
|
||||||
SuperPoint model. It consists of a SuperPointEncoder, a SuperPointInterestPointDecoder and a
|
SuperPoint model. It consists of a SuperPointEncoder, a SuperPointInterestPointDecoder and a
|
||||||
SuperPointDescriptorDecoder. SuperPoint was proposed in `SuperPoint: Self-Supervised Interest Point Detection and
|
SuperPointDescriptorDecoder. SuperPoint was proposed in `SuperPoint: Self-Supervised Interest Point Detection and
|
||||||
|
|||||||
@@ -606,6 +606,9 @@ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = None
|
|||||||
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None
|
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_FOR_KEYPOINT_DETECTION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
MODEL_FOR_MASK_GENERATION_MAPPING = None
|
MODEL_FOR_MASK_GENERATION_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
@@ -778,6 +781,13 @@ class AutoModelForInstanceSegmentation(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForKeypointDetection(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForMaskedImageModeling(metaclass=DummyObject):
|
class AutoModelForMaskedImageModeling(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -8029,7 +8039,7 @@ class Starcoder2PreTrainedModel(metaclass=DummyObject):
|
|||||||
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
class SuperPointModel(metaclass=DummyObject):
|
class SuperPointForKeypointDetection(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
SuperPointModel,
|
SuperPointForKeypointDetection,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
@@ -86,7 +86,7 @@ class SuperPointModelTester:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, pixel_values):
|
def create_and_check_model(self, config, pixel_values):
|
||||||
model = SuperPointModel(config=config)
|
model = SuperPointForKeypointDetection(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
result = model(pixel_values)
|
result = model(pixel_values)
|
||||||
@@ -109,7 +109,7 @@ class SuperPointModelTester:
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (SuperPointModel,) if is_torch_available() else ()
|
all_model_classes = (SuperPointForKeypointDetection,) if is_torch_available() else ()
|
||||||
all_generative_model_classes = () if is_torch_available() else ()
|
all_generative_model_classes = () if is_torch_available() else ()
|
||||||
|
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
@@ -134,31 +134,31 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def create_and_test_config_common_properties(self):
|
def create_and_test_config_common_properties(self):
|
||||||
return
|
return
|
||||||
|
|
||||||
@unittest.skip(reason="SuperPointModel does not use inputs_embeds")
|
@unittest.skip(reason="SuperPointForKeypointDetection does not use inputs_embeds")
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="SuperPointModel does not support input and output embeddings")
|
@unittest.skip(reason="SuperPointForKeypointDetection does not support input and output embeddings")
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="SuperPointModel does not use feedforward chunking")
|
@unittest.skip(reason="SuperPointForKeypointDetection does not use feedforward chunking")
|
||||||
def test_feed_forward_chunking(self):
|
def test_feed_forward_chunking(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="SuperPointModel is not trainable")
|
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||||
def test_training(self):
|
def test_training(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="SuperPointModel is not trainable")
|
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||||
def test_training_gradient_checkpointing(self):
|
def test_training_gradient_checkpointing(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="SuperPointModel is not trainable")
|
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="SuperPointModel is not trainable")
|
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -219,7 +219,7 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
model = SuperPointModel.from_pretrained(model_name)
|
model = SuperPointForKeypointDetection.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
def test_forward_labels_should_be_none(self):
|
def test_forward_labels_should_be_none(self):
|
||||||
@@ -254,7 +254,7 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_inference(self):
|
def test_inference(self):
|
||||||
model = SuperPointModel.from_pretrained("magic-leap-community/superpoint").to(torch_device)
|
model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint").to(torch_device)
|
||||||
preprocessor = self.default_image_processor
|
preprocessor = self.default_image_processor
|
||||||
images = prepare_imgs()
|
images = prepare_imgs()
|
||||||
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
|
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user