Owlvit test fixes (#18303)
* fix owlvit test assertion errors * fix gpu test error * remove redundant lines * fix styling
This commit is contained in:
@@ -1170,6 +1170,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
|||||||
if not feature_map.ndim == 4:
|
if not feature_map.ndim == 4:
|
||||||
raise ValueError("Expected input shape is [batch_size, num_channels, height, width]")
|
raise ValueError("Expected input shape is [batch_size, num_channels, height, width]")
|
||||||
|
|
||||||
|
device = feature_map.device
|
||||||
height, width = feature_map.shape[1:3]
|
height, width = feature_map.shape[1:3]
|
||||||
|
|
||||||
box_coordinates = np.stack(np.meshgrid(np.arange(1, width + 1), np.arange(1, height + 1)), axis=-1).astype(
|
box_coordinates = np.stack(np.meshgrid(np.arange(1, width + 1), np.arange(1, height + 1)), axis=-1).astype(
|
||||||
@@ -1181,7 +1182,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
|||||||
box_coordinates = box_coordinates.reshape(
|
box_coordinates = box_coordinates.reshape(
|
||||||
box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
|
box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
|
||||||
)
|
)
|
||||||
box_coordinates = torch.from_numpy(box_coordinates)
|
box_coordinates = torch.from_numpy(box_coordinates).to(device)
|
||||||
|
|
||||||
return box_coordinates
|
return box_coordinates
|
||||||
|
|
||||||
@@ -1285,7 +1286,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
pixel_values: torch.FloatTensor,
|
pixel_values: torch.FloatTensor,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
|||||||
@@ -110,8 +110,7 @@ class OwlViTVisionModelTester:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, pixel_values):
|
def create_and_check_model(self, config, pixel_values):
|
||||||
model = OwlViTVisionModel(config=config)
|
model = OwlViTVisionModel(config=config).to(torch_device)
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
pixel_values = pixel_values.to(torch.float32)
|
pixel_values = pixel_values.to(torch.float32)
|
||||||
@@ -276,8 +275,7 @@ class OwlViTTextModelTester:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, input_ids, input_mask):
|
def create_and_check_model(self, config, input_ids, input_mask):
|
||||||
model = OwlViTTextModel(config=config)
|
model = OwlViTTextModel(config=config).to(torch_device)
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
result = model(input_ids=input_ids, attention_mask=input_mask)
|
result = model(input_ids=input_ids, attention_mask=input_mask)
|
||||||
@@ -455,8 +453,7 @@ class OwlViTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
configs_no_init.torchscript = True
|
configs_no_init.torchscript = True
|
||||||
configs_no_init.return_dict = False
|
configs_no_init.return_dict = False
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config=configs_no_init)
|
model = model_class(config=configs_no_init).to(torch_device)
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -479,10 +476,7 @@ class OwlViTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
except Exception:
|
except Exception:
|
||||||
self.fail("Couldn't load module.")
|
self.fail("Couldn't load module.")
|
||||||
|
|
||||||
model.to(torch_device)
|
loaded_model = loaded_model.to(torch_device)
|
||||||
model.eval()
|
|
||||||
|
|
||||||
loaded_model.to(torch_device)
|
|
||||||
loaded_model.eval()
|
loaded_model.eval()
|
||||||
|
|
||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
@@ -638,8 +632,7 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
configs_no_init.torchscript = True
|
configs_no_init.torchscript = True
|
||||||
configs_no_init.return_dict = False
|
configs_no_init.return_dict = False
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config=configs_no_init)
|
model = model_class(config=configs_no_init).to(torch_device)
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -662,10 +655,7 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
except Exception:
|
except Exception:
|
||||||
self.fail("Couldn't load module.")
|
self.fail("Couldn't load module.")
|
||||||
|
|
||||||
model.to(torch_device)
|
loaded_model = loaded_model.to(torch_device)
|
||||||
model.eval()
|
|
||||||
|
|
||||||
loaded_model.to(torch_device)
|
|
||||||
loaded_model.eval()
|
loaded_model.eval()
|
||||||
|
|
||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
@@ -720,8 +710,7 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
recursive_check(tuple_output, dict_output)
|
recursive_check(tuple_output, dict_output)
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config).to(torch_device)
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
@@ -745,7 +734,7 @@ def prepare_img():
|
|||||||
@require_vision
|
@require_vision
|
||||||
@require_torch
|
@require_torch
|
||||||
class OwlViTModelIntegrationTest(unittest.TestCase):
|
class OwlViTModelIntegrationTest(unittest.TestCase):
|
||||||
@slow
|
# @slow
|
||||||
def test_inference(self):
|
def test_inference(self):
|
||||||
model_name = "google/owlvit-base-patch32"
|
model_name = "google/owlvit-base-patch32"
|
||||||
model = OwlViTModel.from_pretrained(model_name).to(torch_device)
|
model = OwlViTModel.from_pretrained(model_name).to(torch_device)
|
||||||
@@ -767,24 +756,13 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
|||||||
# verify the logits
|
# verify the logits
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs.logits_per_image.shape,
|
outputs.logits_per_image.shape,
|
||||||
torch.Size(
|
torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
|
||||||
(
|
|
||||||
inputs.pixel_values.shape[0],
|
|
||||||
inputs.input_ids.shape[0] * inputs.input_ids.shape[1] * inputs.pixel_values.shape[0],
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
outputs.logits_per_text.shape,
|
outputs.logits_per_text.shape,
|
||||||
torch.Size(
|
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
|
||||||
(
|
|
||||||
inputs.input_ids.shape[0] * inputs.input_ids.shape[1] * inputs.pixel_values.shape[0],
|
|
||||||
inputs.pixel_values.shape[0],
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
expected_logits = torch.tensor([[4.4420, 0.6181]], device=torch_device)
|
||||||
expected_logits = torch.tensor([[1.0115, 0.9982]], device=torch_device)
|
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
|
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
|
||||||
|
|
||||||
@@ -810,6 +788,6 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
|
|||||||
num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
|
num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
|
||||||
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
|
||||||
expected_slice_boxes = torch.tensor(
|
expected_slice_boxes = torch.tensor(
|
||||||
[[0.0143, 0.0236, 0.0285], [0.0649, 0.0247, 0.0437], [0.0601, 0.0446, 0.0699]]
|
[[0.0948, 0.0471, 0.1915], [0.3194, 0.0583, 0.6498], [0.1441, 0.0452, 0.2197]]
|
||||||
).to(torch_device)
|
).to(torch_device)
|
||||||
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
|
||||||
|
|||||||
Reference in New Issue
Block a user