use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -1821,7 +1821,7 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(logits.shape == (*decoder_input_ids.shape, model.decoder.config.vocab_size))
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :16].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
torch.testing.assert_close(logits[0, 0, :16].cpu(), EXPECTED_LOGITS, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_logits_text_audio_prompt(self):
|
||||
@@ -1859,7 +1859,7 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(logits.shape == (8, 50, 2048))
|
||||
self.assertTrue(torch.allclose(logits[0, -1, :16].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
torch.testing.assert_close(logits[0, -1, :16].cpu(), EXPECTED_LOGITS, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_generate_unconditional_greedy(self):
|
||||
@@ -1881,7 +1881,7 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(output_values.shape == (1, 1, 3200))
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4))
|
||||
torch.testing.assert_close(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_generate_unconditional_sampling(self):
|
||||
@@ -1904,7 +1904,7 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(output_values.shape == (2, 1, 4480))
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4))
|
||||
torch.testing.assert_close(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_generate_text_prompt_greedy(self):
|
||||
@@ -1931,7 +1931,7 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(output_values.shape == (2, 1, 4480))
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, :10].cpu(), EXPECTED_VALUES, atol=1e-4))
|
||||
torch.testing.assert_close(output_values[0, 0, :10].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_generate_text_prompt_greedy_with_classifier_free_guidance(self):
|
||||
@@ -1958,7 +1958,7 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(output_values.shape == (2, 1, 4480))
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4))
|
||||
torch.testing.assert_close(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_generate_text_prompt_sampling(self):
|
||||
@@ -1986,7 +1986,7 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(output_values.shape == (2, 1, 4480))
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, atol=1e-4))
|
||||
torch.testing.assert_close(output_values[0, 0, :16].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_generate_text_audio_prompt(self):
|
||||
@@ -2013,7 +2013,7 @@ class MusicgenIntegrationTests(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
output_values.shape == (2, 1, 36480)
|
||||
) # input values take shape 32000 and we generate from there
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, atol=1e-4))
|
||||
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -2053,8 +2053,8 @@ class MusicgenStereoIntegrationTests(unittest.TestCase):
|
||||
|
||||
# (bsz, channels, seq_len)
|
||||
self.assertTrue(output_values.shape == (1, 2, 5760))
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(output_values[0, 1, :16].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))
|
||||
torch.testing.assert_close(output_values[0, 0, :16].cpu(), EXPECTED_VALUES_LEFT, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(output_values[0, 1, :16].cpu(), EXPECTED_VALUES_RIGHT, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_generate_text_audio_prompt(self):
|
||||
@@ -2088,5 +2088,5 @@ class MusicgenStereoIntegrationTests(unittest.TestCase):
|
||||
# (bsz, channels, seq_len)
|
||||
self.assertTrue(output_values.shape == (2, 2, 37760))
|
||||
# input values take shape 32000 and we generate from there - we check the last (generated) values
|
||||
self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))
|
||||
torch.testing.assert_close(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, rtol=1e-4, atol=1e-4)
|
||||
|
||||
Reference in New Issue
Block a user