use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -148,7 +148,7 @@ class Pop2PianoFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittes
|
||||
EXPECTED_INPUT_FEATURES = torch.tensor(
|
||||
[[-7.1493, -6.8701, -4.3214], [-5.9473, -5.7548, -3.8438], [-6.1324, -5.9018, -4.3778]]
|
||||
)
|
||||
self.assertTrue(torch.allclose(input_features[0, :3, :3], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||
torch.testing.assert_close(input_features[0, :3, :3], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_attention_mask(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
|
||||
@@ -691,7 +691,7 @@ class Pop2PianoModelIntegrationTests(unittest.TestCase):
|
||||
[[1.0475305318832397, 0.29052114486694336, -0.47778210043907166], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs[0, :3, :3], EXPECTED_OUTPUTS, atol=1e-4))
|
||||
torch.testing.assert_close(outputs[0, :3, :3], EXPECTED_OUTPUTS, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
@require_essentia
|
||||
|
||||
@@ -87,8 +87,8 @@ class Pop2PianoTokenizerTest(unittest.TestCase):
|
||||
)
|
||||
expected_output_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]])
|
||||
|
||||
self.assertTrue(torch.allclose(output["token_ids"], expected_output_token_ids, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(output["attention_mask"], expected_output_attention_mask, atol=1e-4))
|
||||
torch.testing.assert_close(output["token_ids"], expected_output_token_ids, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(output["attention_mask"], expected_output_attention_mask, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_batch_decode(self):
|
||||
# test batch decode with model, feature-extractor outputs(beatsteps, extrapolated_beatstep)
|
||||
@@ -174,7 +174,7 @@ class Pop2PianoTokenizerTest(unittest.TestCase):
|
||||
)
|
||||
predicted_start_timings = torch.tensor(predicted_start_timings)
|
||||
|
||||
self.assertTrue(torch.allclose(expected_start_timings, predicted_start_timings, atol=1e-4))
|
||||
torch.testing.assert_close(expected_start_timings, predicted_start_timings, rtol=1e-4, atol=1e-4)
|
||||
|
||||
# Checking note end timings
|
||||
expected_end_timings = torch.tensor(
|
||||
@@ -187,7 +187,7 @@ class Pop2PianoTokenizerTest(unittest.TestCase):
|
||||
)
|
||||
predicted_end_timings = torch.tensor(predicted_end_timings)
|
||||
|
||||
self.assertTrue(torch.allclose(expected_end_timings, predicted_end_timings, atol=1e-4))
|
||||
torch.testing.assert_close(expected_end_timings, predicted_end_timings, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_get_vocab(self):
|
||||
vocab_dict = self.tokenizer.get_vocab()
|
||||
|
||||
Reference in New Issue
Block a user