SuperPointModel -> SuperPointForKeypointDetection (#29757)
This commit is contained in:
@@ -28,7 +28,7 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
SuperPointModel,
|
||||
SuperPointForKeypointDetection,
|
||||
)
|
||||
|
||||
if is_vision_available():
|
||||
@@ -86,7 +86,7 @@ class SuperPointModelTester:
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
model = SuperPointModel(config=config)
|
||||
model = SuperPointForKeypointDetection(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
@@ -109,7 +109,7 @@ class SuperPointModelTester:
|
||||
|
||||
@require_torch
|
||||
class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (SuperPointModel,) if is_torch_available() else ()
|
||||
all_model_classes = (SuperPointForKeypointDetection,) if is_torch_available() else ()
|
||||
all_generative_model_classes = () if is_torch_available() else ()
|
||||
|
||||
fx_compatible = False
|
||||
@@ -134,31 +134,31 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def create_and_test_config_common_properties(self):
|
||||
return
|
||||
|
||||
@unittest.skip(reason="SuperPointModel does not use inputs_embeds")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointModel does not support input and output embeddings")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection does not support input and output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointModel does not use feedforward chunking")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection does not use feedforward chunking")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointModel is not trainable")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointModel is not trainable")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointModel is not trainable")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SuperPointModel is not trainable")
|
||||
@unittest.skip(reason="SuperPointForKeypointDetection is not trainable")
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@@ -219,7 +219,7 @@ class SuperPointModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in SUPERPOINT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = SuperPointModel.from_pretrained(model_name)
|
||||
model = SuperPointForKeypointDetection.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_forward_labels_should_be_none(self):
|
||||
@@ -254,7 +254,7 @@ class SuperPointModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_inference(self):
|
||||
model = SuperPointModel.from_pretrained("magic-leap-community/superpoint").to(torch_device)
|
||||
model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint").to(torch_device)
|
||||
preprocessor = self.default_image_processor
|
||||
images = prepare_imgs()
|
||||
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user