use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -438,14 +438,14 @@ class DacIntegrationTest(unittest.TestCase):
|
||||
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
|
||||
torch.testing.assert_close(encoder_outputs_mean, expected_encoder_sums, rtol=1e-3, atol=1e-3)
|
||||
|
||||
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
|
||||
input_values_dec = model.decode(quantized_representation)[0]
|
||||
input_values_enc_dec = model(inputs["input_values"])[1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
torch.testing.assert_close(input_values_dec, input_values_enc_dec, rtol=1e-3, atol=1e-3)
|
||||
|
||||
arr = inputs["input_values"][0].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
|
||||
@@ -515,10 +515,10 @@ class DacIntegrationTest(unittest.TestCase):
|
||||
input_values_from_codes = model.decode(audio_codes=encoder_outputs.audio_codes)[0]
|
||||
|
||||
# make sure decode from audio codes and quantized values give more or less the same results
|
||||
self.assertTrue(torch.allclose(input_values_from_codes, input_values_dec, atol=1e-5))
|
||||
torch.testing.assert_close(input_values_from_codes, input_values_dec, rtol=1e-5, atol=1e-5)
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
torch.testing.assert_close(input_values_dec, input_values_enc_dec, rtol=1e-3, atol=1e-3)
|
||||
|
||||
arr = inputs["input_values"][0].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
|
||||
@@ -565,14 +565,14 @@ class DacIntegrationTest(unittest.TestCase):
|
||||
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
|
||||
torch.testing.assert_close(encoder_outputs_mean, expected_encoder_sums, rtol=1e-3, atol=1e-3)
|
||||
|
||||
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
|
||||
input_values_dec = model.decode(quantized_representation)[0]
|
||||
input_values_enc_dec = model(inputs["input_values"])[1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
torch.testing.assert_close(input_values_dec, input_values_enc_dec, rtol=1e-3, atol=1e-3)
|
||||
|
||||
arr = inputs["input_values"][0].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
|
||||
@@ -622,14 +622,14 @@ class DacIntegrationTest(unittest.TestCase):
|
||||
encoder_outputs_mean = torch.tensor([v.float().mean().item() for v in encoder_outputs.to_tuple()])
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
|
||||
torch.testing.assert_close(encoder_outputs_mean, expected_encoder_sums, rtol=1e-3, atol=1e-3)
|
||||
|
||||
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
|
||||
input_values_dec = model.decode(quantized_representation)[0]
|
||||
input_values_enc_dec = model(inputs["input_values"])[1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
torch.testing.assert_close(input_values_dec, input_values_enc_dec, rtol=1e-3, atol=1e-3)
|
||||
|
||||
arr = inputs["input_values"].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec.cpu().numpy()
|
||||
@@ -679,14 +679,14 @@ class DacIntegrationTest(unittest.TestCase):
|
||||
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
|
||||
torch.testing.assert_close(encoder_outputs_mean, expected_encoder_sums, rtol=1e-3, atol=1e-3)
|
||||
|
||||
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
|
||||
input_values_dec = model.decode(quantized_representation)[0]
|
||||
input_values_enc_dec = model(inputs["input_values"])[1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
torch.testing.assert_close(input_values_dec, input_values_enc_dec, rtol=1e-3, atol=1e-3)
|
||||
|
||||
arr = inputs["input_values"].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec.cpu().numpy()
|
||||
@@ -736,14 +736,14 @@ class DacIntegrationTest(unittest.TestCase):
|
||||
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
|
||||
torch.testing.assert_close(encoder_outputs_mean, expected_encoder_sums, rtol=1e-3, atol=1e-3)
|
||||
|
||||
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
|
||||
input_values_dec = model.decode(quantized_representation)[0]
|
||||
input_values_enc_dec = model(inputs["input_values"])[1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
torch.testing.assert_close(input_values_dec, input_values_enc_dec, rtol=1e-3, atol=1e-3)
|
||||
|
||||
arr = inputs["input_values"].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec.cpu().numpy()
|
||||
|
||||
Reference in New Issue
Block a user