GitModelIntegrationTest - flatten the expected slice tensor (#36260)
Flatten the expected slice tensor
This commit is contained in:
@@ -549,7 +549,7 @@ class GitModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(outputs.sequences.shape, expected_shape)
|
||||
self.assertEqual(generated_caption, "two cats laying on a pink blanket")
|
||||
self.assertTrue(outputs.scores[-1].shape, expected_shape)
|
||||
expected_slice = torch.tensor([[-0.8805, -0.8803, -0.8799]], device=torch_device)
|
||||
expected_slice = torch.tensor([-0.8805, -0.8803, -0.8799], device=torch_device)
|
||||
torch.testing.assert_close(outputs.scores[-1][0, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_visual_question_answering(self):
|
||||
|
||||
Reference in New Issue
Block a user