Fix test_inference_instance_segmentation_head (#17872)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -387,9 +387,12 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
|
||||
)
|
||||
expected_slice = torch.tensor(
|
||||
[[-1.3738, -1.7725, -1.9365], [-1.5978, -1.9869, -2.1524], [-1.5796, -1.9271, -2.0940]]
|
||||
).to(torch_device)
|
||||
expected_slice = [
|
||||
[-1.3737124, -1.7724937, -1.9364233],
|
||||
[-1.5977281, -1.9867939, -2.1523695],
|
||||
[-1.5795398, -1.9269832, -2.093942],
|
||||
]
|
||||
expected_slice = torch.tensor(expected_slice).to(torch_device)
|
||||
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
|
||||
# class_queries_logits
|
||||
class_queries_logits = outputs.class_queries_logits
|
||||
|
||||
Reference in New Issue
Block a user