From b03be78a4bc3223695ed4f738375148acd487007 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 24 Jun 2022 19:36:45 +0200 Subject: [PATCH] Fix `test_inference_instance_segmentation_head` (#17872) Co-authored-by: ydshieh --- tests/models/maskformer/test_modeling_maskformer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/models/maskformer/test_modeling_maskformer.py b/tests/models/maskformer/test_modeling_maskformer.py index 1c64ca46a5..b1e6121061 100644 --- a/tests/models/maskformer/test_modeling_maskformer.py +++ b/tests/models/maskformer/test_modeling_maskformer.py @@ -387,9 +387,12 @@ class MaskFormerModelIntegrationTest(unittest.TestCase): self.assertEqual( masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4) ) - expected_slice = torch.tensor( - [[-1.3738, -1.7725, -1.9365], [-1.5978, -1.9869, -2.1524], [-1.5796, -1.9271, -2.0940]] - ).to(torch_device) + expected_slice = [ + [-1.3737124, -1.7724937, -1.9364233], + [-1.5977281, -1.9867939, -2.1523695], + [-1.5795398, -1.9269832, -2.093942], + ] + expected_slice = torch.tensor(expected_slice).to(torch_device) self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE)) # class_queries_logits class_queries_logits = outputs.class_queries_logits