GitModelIntegrationTest - flatten the expected slice tensor (#36260)
Flatten the expected slice tensor
This commit is contained in:
@@ -549,7 +549,7 @@ class GitModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(outputs.sequences.shape, expected_shape)
|
self.assertEqual(outputs.sequences.shape, expected_shape)
|
||||||
self.assertEqual(generated_caption, "two cats laying on a pink blanket")
|
self.assertEqual(generated_caption, "two cats laying on a pink blanket")
|
||||||
self.assertTrue(outputs.scores[-1].shape, expected_shape)
|
self.assertTrue(outputs.scores[-1].shape, expected_shape)
|
||||||
expected_slice = torch.tensor([[-0.8805, -0.8803, -0.8799]], device=torch_device)
|
expected_slice = torch.tensor([-0.8805, -0.8803, -0.8799], device=torch_device)
|
||||||
torch.testing.assert_close(outputs.scores[-1][0, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
torch.testing.assert_close(outputs.scores[-1][0, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||||
|
|
||||||
def test_visual_question_answering(self):
|
def test_visual_question_answering(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user