use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -390,10 +390,10 @@ class JetMoeIntegrationTest(unittest.TestCase):
|
||||
out = model(input_ids).logits.float().cpu()
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[0.2507, -2.7073, -1.3445, -1.9363, -1.7216, -1.7370, -1.9054, -1.9792]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
|
||||
# slicing logits[0, 0, 0:30]
|
||||
EXPECTED_SLICE = torch.tensor([-3.3689, 5.9006, 5.7450, -1.7012, -4.7072, -4.7071, -4.7071, -4.7071, -4.7072, -4.7072, -4.7072, -4.7071, 3.8321, 9.1746, -4.7071, -4.7072, -4.7071, -4.7072, -4.7071, -4.7072, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071]) # fmt: skip
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4)
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user