use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -240,7 +240,7 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
|
||||
self.assertEqual(input_features.shape, (1, 80, 3000))
|
||||
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||
torch.testing.assert_close(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@unittest.mock.patch("transformers.models.whisper.feature_extraction_whisper.is_torch_available", lambda: False)
|
||||
def test_numpy_integration(self):
|
||||
@@ -302,4 +302,4 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
self.assertEqual(input_features.shape, (3, 80, 3000))
|
||||
self.assertTrue(torch.allclose(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
|
||||
torch.testing.assert_close(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@@ -499,7 +499,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
embeds = model.get_encoder().embed_positions.weight
|
||||
self.assertTrue(torch.allclose(embeds, sinusoids(*embeds.shape)))
|
||||
torch.testing.assert_close(embeds, sinusoids(*embeds.shape))
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -924,7 +924,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
logits_fa = outputs_fa.decoder_hidden_states[-1]
|
||||
|
||||
# whisper FA2 needs very high tolerance
|
||||
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-1))
|
||||
torch.testing.assert_close(logits_fa, logits, rtol=4e-1, atol=4e-1)
|
||||
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
@@ -969,7 +969,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
logits_fa = outputs_fa.decoder_hidden_states[-1]
|
||||
|
||||
# whisper FA2 needs very high tolerance
|
||||
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-1))
|
||||
torch.testing.assert_close(logits_fa, logits, rtol=4e-1, atol=4e-1)
|
||||
|
||||
other_inputs = {
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
@@ -984,7 +984,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
logits_fa = outputs_fa.decoder_hidden_states[-1]
|
||||
|
||||
# whisper FA2 needs very high tolerance
|
||||
self.assertTrue(torch.allclose(logits_fa[:, -2:], logits[:, -2:], atol=4e-1))
|
||||
torch.testing.assert_close(logits_fa[:, -2:], logits[:, -2:], rtol=4e-1, atol=4e-1)
|
||||
|
||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||
if not self.test_torchscript:
|
||||
@@ -1663,7 +1663,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
torch.testing.assert_close(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, rtol=1e-4, atol=1e-4)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_GENERATION = torch.tensor(
|
||||
@@ -1677,7 +1677,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
|
||||
head_logits = logits[0] @ model.decoder.embed_tokens.weight.T
|
||||
self.assertTrue(torch.allclose(head_logits[0, 0, :30].cpu(), EXPECTED_GENERATION, atol=1e-4))
|
||||
torch.testing.assert_close(head_logits[0, 0, :30].cpu(), EXPECTED_GENERATION, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_small_en_logits_librispeech(self):
|
||||
@@ -1712,7 +1712,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
torch.testing.assert_close(logits[0, 0, :30].cpu(), EXPECTED_LOGITS, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_large_logits_librispeech(self):
|
||||
@@ -1756,7 +1756,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
torch.testing.assert_close(logits[0, 0, :30].cpu(), EXPECTED_LOGITS, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_tiny_en_generation(self):
|
||||
@@ -1868,7 +1868,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(generated_ids.cpu(), EXPECTED_LOGITS))
|
||||
torch.testing.assert_close(generated_ids.cpu(), EXPECTED_LOGITS)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
@@ -1942,7 +1942,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(generated_ids, EXPECTED_LOGITS))
|
||||
torch.testing.assert_close(generated_ids, EXPECTED_LOGITS)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
@@ -1975,7 +1975,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT))
|
||||
torch.testing.assert_close(generated_ids, EXPECTED_OUTPUT)
|
||||
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
{
|
||||
@@ -2216,7 +2216,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50430
|
||||
])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT))
|
||||
torch.testing.assert_close(generated_ids, EXPECTED_OUTPUT)
|
||||
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
{
|
||||
@@ -2292,7 +2292,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
|
||||
torch.testing.assert_close(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT)
|
||||
|
||||
@slow
|
||||
def test_small_token_timestamp_generation(self):
|
||||
@@ -2322,7 +2322,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
|
||||
torch.testing.assert_close(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT)
|
||||
|
||||
@slow
|
||||
def test_tiny_token_timestamp_batch_generation(self):
|
||||
@@ -2403,7 +2403,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
|
||||
for segment, exp_segment in zip(generate_outputs["segments"][0], EXPECTED_OUTPUT):
|
||||
self.assertTrue(torch.allclose(segment["token_timestamps"], exp_segment))
|
||||
torch.testing.assert_close(segment["token_timestamps"], exp_segment)
|
||||
|
||||
@slow
|
||||
def test_tiny_specaugment_librispeech(self):
|
||||
@@ -2438,7 +2438,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
torch.testing.assert_close(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_generate_with_prompt_ids(self):
|
||||
|
||||
Reference in New Issue
Block a user