GitModelIntegrationTest - flatten the expected slice tensor (#36260)

Flatten the expected slice tensor
This commit is contained in:
ivarflakstad
2025-02-18 16:04:19 +01:00
committed by GitHub
parent 4d2de5f63c
commit 07182b2e10

View File

@@ -549,7 +549,7 @@ class GitModelIntegrationTest(unittest.TestCase):
self.assertEqual(outputs.sequences.shape, expected_shape)
self.assertEqual(generated_caption, "two cats laying on a pink blanket")
self.assertTrue(outputs.scores[-1].shape, expected_shape)
expected_slice = torch.tensor([[-0.8805, -0.8803, -0.8799]], device=torch_device)
expected_slice = torch.tensor([-0.8805, -0.8803, -0.8799], device=torch_device)
torch.testing.assert_close(outputs.scores[-1][0, :3], expected_slice, rtol=1e-4, atol=1e-4)
def test_visual_question_answering(self):