Fix OneFormer integration test (#38016)

* Fix integration tests

* format
This commit is contained in:
Pavel Iakubovskii
2025-05-12 15:02:41 +01:00
committed by GitHub
parent 8efe3a9d77
commit 4220039b29

View File

@@ -528,32 +528,22 @@ class OneFormerModelIntegrationTest(unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs)
expected_slice_hidden_state = torch.tensor( expected_slice_hidden_state = [[0.2723, 0.8280, 0.6026], [1.2699, 1.1257, 1.1444], [1.1344, 0.6153, 0.4177]]
[[0.2723, 0.8280, 0.6026], [1.2699, 1.1257, 1.1444], [1.1344, 0.6153, 0.4177]] expected_slice_hidden_state = torch.tensor(expected_slice_hidden_state).to(torch_device)
).to(torch_device) slice_hidden_state = outputs.encoder_hidden_states[-1][0, 0, :3, :3]
self.assertTrue( torch.testing.assert_close(slice_hidden_state, expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE)
torch.allclose(
outputs.encoder_hidden_states[-1][0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
)
)
expected_slice_hidden_state = torch.tensor( expected_slice_hidden_state = [[1.0581, 1.2276, 1.2003], [1.1903, 1.2925, 1.2862], [1.158, 1.2559, 1.3216]]
[[1.0581, 1.2276, 1.2003], [1.1903, 1.2925, 1.2862], [1.158, 1.2559, 1.3216]] expected_slice_hidden_state = torch.tensor(expected_slice_hidden_state).to(torch_device)
).to(torch_device) slice_hidden_state = outputs.pixel_decoder_hidden_states[0][0, 0, :3, :3]
self.assertTrue( torch.testing.assert_close(slice_hidden_state, expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE)
torch.allclose(
outputs.pixel_decoder_hidden_states[0][0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
)
)
expected_slice_hidden_state = torch.tensor( # fmt: off
[[3.0668, -1.1833, -5.1103], [3.344, -3.362, -5.1101], [2.6017, -4.3613, -4.1444]] expected_slice_hidden_state = [[3.0668, -1.1833, -5.1103], [3.344, -3.362, -5.1101], [2.6017, -4.3613, -4.1444]]
).to(torch_device) expected_slice_hidden_state = torch.tensor(expected_slice_hidden_state).to(torch_device)
self.assertTrue( slice_hidden_state = outputs.transformer_decoder_class_predictions[0, :3, :3]
torch.allclose( torch.testing.assert_close(slice_hidden_state, expected_slice_hidden_state, atol=TOLERANCE, rtol=TOLERANCE)
outputs.transformer_decoder_class_predictions[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE # fmt: on
)
)
def test_inference_universal_segmentation_head(self): def test_inference_universal_segmentation_head(self):
model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval() model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
@@ -573,18 +563,18 @@ class OneFormerModelIntegrationTest(unittest.TestCase):
masks_queries_logits.shape, masks_queries_logits.shape,
(1, model.config.num_queries, inputs_shape[-2] // 4, (inputs_shape[-1] + 2) // 4), (1, model.config.num_queries, inputs_shape[-2] // 4, (inputs_shape[-1] + 2) // 4),
) )
expected_slice = [[[3.1848, 4.2141, 4.1993], [2.9000, 3.5721, 3.6603], [2.5358, 3.0883, 3.6168]]] expected_slice = [[3.1848, 4.2141, 4.1993], [2.9000, 3.5721, 3.6603], [2.5358, 3.0883, 3.6168]]
expected_slice = torch.tensor(expected_slice).to(torch_device) expected_slice = torch.tensor(expected_slice).to(torch_device)
torch.testing.assert_close(masks_queries_logits[0, 0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) torch.testing.assert_close(masks_queries_logits[0, 0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
# class_queries_logits # class_queries_logits
class_queries_logits = outputs.class_queries_logits class_queries_logits = outputs.class_queries_logits
self.assertEqual( self.assertEqual(
class_queries_logits.shape, class_queries_logits.shape,
(1, model.config.num_queries, model.config.num_labels + 1), (1, model.config.num_queries, model.config.num_labels + 1),
) )
expected_slice = torch.tensor( expected_slice = [[3.0668, -1.1833, -5.1103], [3.3440, -3.3620, -5.1101], [2.6017, -4.3613, -4.1444]]
[[3.0668, -1.1833, -5.1103], [3.3440, -3.3620, -5.1101], [2.6017, -4.3613, -4.1444]] expected_slice = torch.tensor(expected_slice).to(torch_device)
).to(torch_device)
torch.testing.assert_close(class_queries_logits[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) torch.testing.assert_close(class_queries_logits[0, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
@require_torch_accelerator @require_torch_accelerator