SuperPointModel -> SuperPointForKeypointDetection (#29757)

This commit is contained in:
amyeroberts
2024-03-20 15:41:03 +00:00
committed by GitHub
parent 1248f09252
commit 3c17c529cc
11 changed files with 63 additions and 30 deletions

View File

@@ -28,7 +28,7 @@ if is_torch_available():
from transformers import (
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
SuperPointModel,
SuperPointForKeypointDetection,
)
if is_vision_available():
@@ -86,7 +86,7 @@ class SuperPointModelTester:
)
def create_and_check_model(self, config, pixel_values):
model = SuperPointModel(config=config)
model = SuperPointForKeypointDetection(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
@@ -109,7 +109,7 @@ class SuperPointModelTester:
@require_torch
class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (SuperPointModel,) if is_torch_available() else ()
all_model_classes = (SuperPointForKeypointDetection,) if is_torch_available() else ()
all_generative_model_classes = () if is_torch_available() else ()
fx_compatible = False
@@ -134,31 +134,31 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
def create_and_test_config_common_properties(self):
return
@unittest.skip(reason="SuperPointModel does not use inputs_embeds")
@unittest.skip(reason="SuperPointForKeypointDetection does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="SuperPointModel does not support input and output embeddings")
@unittest.skip(reason="SuperPointForKeypointDetection does not support input and output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="SuperPointModel does not use feedforward chunking")
@unittest.skip(reason="SuperPointForKeypointDetection does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="SuperPointModel is not trainable")
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@@ -219,7 +219,7 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
@slow
def test_model_from_pretrained(self):
for model_name in SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = SuperPointModel.from_pretrained(model_name)
model = SuperPointForKeypointDetection.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_forward_labels_should_be_none(self):
@@ -254,7 +254,7 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
@slow
def test_inference(self):
model = SuperPointModel.from_pretrained("magic-leap-community/superpoint").to(torch_device)
model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint").to(torch_device)
preprocessor = self.default_image_processor
images = prepare_imgs()
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)