use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -159,7 +159,7 @@ class EnCodecFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
feature_extractor = EncodecFeatureExtractor()
|
||||
input_values = feature_extractor(input_audio, return_tensors="pt").input_values
|
||||
self.assertEqual(input_values.shape, (1, 1, 93680))
|
||||
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-6))
|
||||
torch.testing.assert_close(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, rtol=1e-6, atol=1e-6)
|
||||
|
||||
def test_integration_stereo(self):
|
||||
# fmt: off
|
||||
@@ -178,8 +178,8 @@ class EnCodecFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
feature_extractor = EncodecFeatureExtractor(feature_size=2)
|
||||
input_values = feature_extractor(input_audio, return_tensors="pt").input_values
|
||||
self.assertEqual(input_values.shape, (1, 2, 93680))
|
||||
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-6))
|
||||
self.assertTrue(torch.allclose(input_values[0, 1, :30], EXPECTED_INPUT_VALUES * 0.5, atol=1e-6))
|
||||
torch.testing.assert_close(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, rtol=1e-6, atol=1e-6)
|
||||
torch.testing.assert_close(input_values[0, 1, :30], EXPECTED_INPUT_VALUES * 0.5, rtol=1e-6, atol=1e-6)
|
||||
|
||||
def test_truncation_and_padding(self):
|
||||
input_audio = self._load_datasamples(2)
|
||||
|
||||
@@ -324,7 +324,7 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
inputs["input_values"] = inputs["input_values"].repeat(1, 1, 10)
|
||||
|
||||
hidden_states_no_chunk = model(**inputs)[0]
|
||||
hidden_states_no_chunk = model(**inputs)[1]
|
||||
|
||||
torch.manual_seed(0)
|
||||
config.chunk_length_s = 1
|
||||
@@ -335,8 +335,8 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_with_chunk = model(**inputs)[0]
|
||||
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
||||
hidden_states_with_chunk = model(**inputs)[1]
|
||||
torch.testing.assert_close(hidden_states_no_chunk, hidden_states_with_chunk, rtol=1e-1, atol=1e-2)
|
||||
|
||||
@unittest.skip(
|
||||
reason="The EncodecModel is not transformers based, thus it does not have the usual `hidden_states` logic"
|
||||
@@ -507,7 +507,7 @@ class EncodecIntegrationTest(unittest.TestCase):
|
||||
)[-1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
torch.testing.assert_close(input_values_dec, input_values_enc_dec, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# make sure shape matches
|
||||
self.assertTrue(inputs["input_values"].shape == input_values_enc_dec.shape)
|
||||
@@ -563,7 +563,7 @@ class EncodecIntegrationTest(unittest.TestCase):
|
||||
)[-1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
torch.testing.assert_close(input_values_dec, input_values_enc_dec, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# make sure shape matches
|
||||
self.assertTrue(inputs["input_values"].shape == input_values_enc_dec.shape)
|
||||
@@ -622,7 +622,7 @@ class EncodecIntegrationTest(unittest.TestCase):
|
||||
input_values_enc_dec = model(input_values, bandwidth=float(bandwidth))[-1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
torch.testing.assert_close(input_values_dec, input_values_enc_dec, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# make sure shape matches
|
||||
self.assertTrue(input_values.shape == input_values_enc_dec.shape)
|
||||
|
||||
Reference in New Issue
Block a user