use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -285,8 +285,8 @@ class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt", padding=padding).input_features
|
||||
self.assertEqual(input_features.shape, (1, 4, 1001, 64))
|
||||
|
||||
self.assertTrue(torch.allclose(input_features[0, 0, idx_in_mel[0]], EXPECTED_VALUES[0], atol=1e-4))
|
||||
self.assertTrue(torch.allclose(input_features[0, 0, idx_in_mel[1]], EXPECTED_VALUES[1], atol=1e-4))
|
||||
torch.testing.assert_close(input_features[0, 0, idx_in_mel[0]], EXPECTED_VALUES[0], rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(input_features[0, 0, idx_in_mel[1]], EXPECTED_VALUES[1], rtol=1e-4, atol=1e-4)
|
||||
|
||||
self.assertTrue(torch.all(input_features[0, 0] == input_features[0, 1]))
|
||||
self.assertTrue(torch.all(input_features[0, 0] == input_features[0, 2]))
|
||||
@@ -408,8 +408,8 @@ class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
|
||||
input_speech, return_tensors="pt", truncation="rand_trunc", padding=padding
|
||||
).input_features
|
||||
self.assertEqual(input_features.shape, (1, 1, 1001, 64))
|
||||
self.assertTrue(torch.allclose(input_features[0, 0, idx_in_mel[0]], EXPECTED_VALUES[0], atol=1e-4))
|
||||
self.assertTrue(torch.allclose(input_features[0, 0, idx_in_mel[1]], EXPECTED_VALUES[1], atol=1e-4))
|
||||
torch.testing.assert_close(input_features[0, 0, idx_in_mel[0]], EXPECTED_VALUES[0], rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(input_features[0, 0, idx_in_mel[1]], EXPECTED_VALUES[1], rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_integration_fusion_long_input(self):
|
||||
# fmt: off
|
||||
@@ -475,7 +475,7 @@ class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
|
||||
set_seed(987654321)
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt", padding=padding).input_features
|
||||
self.assertEqual(input_features.shape, (1, 4, 1001, 64))
|
||||
self.assertTrue(torch.allclose(input_features[0, block_idx, MEL_BIN], EXPECTED_VALUES, atol=1e-3))
|
||||
torch.testing.assert_close(input_features[0, block_idx, MEL_BIN], EXPECTED_VALUES, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_integration_rand_trunc_long_input(self):
|
||||
# fmt: off
|
||||
@@ -544,4 +544,4 @@ class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
|
||||
input_speech, return_tensors="pt", truncation="rand_trunc", padding=padding
|
||||
).input_features
|
||||
self.assertEqual(input_features.shape, (1, 1, 1001, 64))
|
||||
self.assertTrue(torch.allclose(input_features[0, 0, MEL_BIN], EXPECTED_VALUES, atol=1e-4))
|
||||
torch.testing.assert_close(input_features[0, 0, MEL_BIN], EXPECTED_VALUES, rtol=1e-4, atol=1e-4)
|
||||
|
||||
Reference in New Issue
Block a user