use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -1227,7 +1227,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
|
||||
expected_slice = torch.tensor(
|
||||
[[[-7.7729, -8.0343, -8.26001], [-7.74213, -7.8629, -8.6000], [-7.7328, -7.8269, -8.5264]]]
|
||||
).to(torch_device)
|
||||
# self.assertTrue(torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4))
|
||||
# torch.testing.assert_close(output_predited_logits[:, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
assert torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4)
|
||||
|
||||
# encoder outputs
|
||||
@@ -1237,7 +1237,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
|
||||
).to(torch_device)
|
||||
expected_shape_encoder = torch.Size((1, 28, 1024))
|
||||
self.assertEqual(encoder_outputs.shape, expected_shape_encoder)
|
||||
# self.assertTrue(torch.allclose(encoder_outputs[:, :3, :3], expected_encoder_outputs_slice, atol=1e-4))
|
||||
# torch.testing.assert_close(encoder_outputs[:, :3, :3], expected_encoder_outputs_slice, rtol=1e-4, atol=1e-4)
|
||||
assert torch.allclose(encoder_outputs[:, :3, :3], expected_encoder_outputs_slice, atol=1e-4)
|
||||
|
||||
# decoder outputs
|
||||
@@ -1245,7 +1245,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
|
||||
predicting_streams = decoder_outputs[1].view(1, model.config.ngram, 12, -1)
|
||||
predicting_streams_logits = model.lm_head(predicting_streams)
|
||||
next_first_stream_logits = predicting_streams_logits[:, 0]
|
||||
# self.assertTrue(torch.allclose(next_first_stream_logits[:, :3, :3], expected_slice, atol=1e-4))
|
||||
# torch.testing.assert_close(next_first_stream_logits[:, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
|
||||
assert torch.allclose(next_first_stream_logits[:, :3, :3], expected_slice, atol=1e-4)
|
||||
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user