use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -1142,7 +1142,7 @@ class TrOCRModelIntegrationTest(unittest.TestCase):
|
||||
[-1.4502, -4.6683, -0.5347, -2.9291, 9.1435, -3.0571, 8.9764, 1.7560, 8.7358, -1.5311]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-4))
|
||||
torch.testing.assert_close(logits[0, 0, :10], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_inference_printed(self):
|
||||
@@ -1176,7 +1176,7 @@ class TrOCRModelIntegrationTest(unittest.TestCase):
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-4))
|
||||
torch.testing.assert_close(logits[0, 0, :10], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
@require_vision
|
||||
@@ -1272,7 +1272,7 @@ class DonutModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([24.3873, -6.4491, 32.5394]).to(torch_device)
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))
|
||||
torch.testing.assert_close(logits[0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
|
||||
# step 2: generation
|
||||
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
|
||||
@@ -1336,7 +1336,7 @@ class DonutModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([-27.4344, -3.2686, -19.3524], device=torch_device)
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))
|
||||
torch.testing.assert_close(logits[0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
|
||||
# step 2: generation
|
||||
task_prompt = "<s_cord-v2>"
|
||||
@@ -1398,7 +1398,7 @@ class DonutModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([-17.6490, -4.8381, -15.7577], device=torch_device)
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))
|
||||
torch.testing.assert_close(logits[0, 0, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
|
||||
# step 2: generation
|
||||
task_prompt = "<s_rvlcdip>"
|
||||
@@ -1475,7 +1475,7 @@ class NougatModelIntegrationTest(unittest.TestCase):
|
||||
[1.6253, -4.2179, 5.8532, -2.7911, -5.0609, -4.7397, -4.2890, -5.1073, -4.8908, -4.9729]
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :10], expected_slice, atol=1e-4))
|
||||
torch.testing.assert_close(logits[0, 0, :10], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_generation(self):
|
||||
processor = self.default_processor
|
||||
|
||||
Reference in New Issue
Block a user