diff --git a/tests/models/git/test_modeling_git.py b/tests/models/git/test_modeling_git.py index 9912d9b6fe..473bbe6db6 100644 --- a/tests/models/git/test_modeling_git.py +++ b/tests/models/git/test_modeling_git.py @@ -549,7 +549,7 @@ class GitModelIntegrationTest(unittest.TestCase): self.assertEqual(outputs.sequences.shape, expected_shape) self.assertEqual(generated_caption, "two cats laying on a pink blanket") self.assertTrue(outputs.scores[-1].shape, expected_shape) - expected_slice = torch.tensor([[-0.8805, -0.8803, -0.8799]], device=torch_device) + expected_slice = torch.tensor([-0.8805, -0.8803, -0.8799], device=torch_device) torch.testing.assert_close(outputs.scores[-1][0, :3], expected_slice, rtol=1e-4, atol=1e-4) def test_visual_question_answering(self):