use torch.testing.assertclose instead to get more details about error in cis (#35659)

* use torch.testing.assertclose instead to get more details about error in cis

* fix

* style

* test_all

* revert for I bert

* fixes and updates

* more image processing fixes

* more image processors

* fix mamba and co

* style

* less strick

* ok I won't be strict

* skip and be done

* up
This commit is contained in:
Arthur
2025-01-24 16:55:28 +01:00
committed by GitHub
parent 72d1a4cd53
commit b912f5ee43
255 changed files with 1048 additions and 969 deletions

View File

@@ -1166,7 +1166,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
logits_2 = get_logits(model_2, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
torch.testing.assert_close(logits, logits_2, rtol=1e-3, atol=1e-3)
# test that loading adapter weights with mismatched vocab sizes can be loaded
def test_load_target_lang_with_mismatched_size(self):
@@ -1203,7 +1203,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
logits_2 = get_logits(model_2, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
torch.testing.assert_close(logits, logits_2, rtol=1e-3, atol=1e-3)
def test_load_attn_adapter(self):
processor = Wav2Vec2Processor.from_pretrained(
@@ -1250,7 +1250,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
model.load_adapter("ita", use_safetensors=True)
logits_2 = get_logits(model, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
torch.testing.assert_close(logits, logits_2, rtol=1e-3, atol=1e-3)
with tempfile.TemporaryDirectory() as tempdir:
model.save_pretrained(tempdir)
@@ -1271,7 +1271,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
logits_2 = get_logits(model, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
torch.testing.assert_close(logits, logits_2, rtol=1e-3, atol=1e-3)
model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter")
logits = get_logits(model, input_features)
@@ -1282,7 +1282,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
logits_2 = get_logits(model, input_features)
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
torch.testing.assert_close(logits, logits_2, rtol=1e-3, atol=1e-3)
@slow
def test_model_from_pretrained(self):
@@ -1595,7 +1595,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
], device=torch_device)
# fmt: on
self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
torch.testing.assert_close(cosine_sim_masked, expected_cosine_sim_masked, rtol=1e-3, atol=1e-3)
def test_inference_pretrained(self):
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
@@ -1734,7 +1734,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
expected_logits = torch.tensor([6.1186, 11.8961, 10.2931, 6.0898], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
torch.testing.assert_close(predicted_logits, expected_logits, rtol=1e-2, atol=1e-2)
def test_inference_intent_classification(self):
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic").to(torch_device)
@@ -1762,9 +1762,9 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
self.assertListEqual(predicted_ids_object.tolist(), expected_labels_object)
self.assertListEqual(predicted_ids_location.tolist(), expected_labels_location)
self.assertTrue(torch.allclose(predicted_logits_action, expected_logits_action, atol=1e-2))
self.assertTrue(torch.allclose(predicted_logits_object, expected_logits_object, atol=1e-2))
self.assertTrue(torch.allclose(predicted_logits_location, expected_logits_location, atol=1e-2))
torch.testing.assert_close(predicted_logits_action, expected_logits_action, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(predicted_logits_object, expected_logits_object, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(predicted_logits_location, expected_logits_location, rtol=1e-2, atol=1e-2)
def test_inference_speaker_identification(self):
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid").to(torch_device)
@@ -1785,7 +1785,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
expected_logits = torch.tensor([37.5627, 71.6362, 64.2419, 31.7778], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
torch.testing.assert_close(predicted_logits, expected_logits, rtol=1e-2, atol=1e-2)
def test_inference_emotion_recognition(self):
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er").to(torch_device)
@@ -1804,7 +1804,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
expected_logits = torch.tensor([2.1722, 3.0779, 8.0287, 6.6797], device=torch_device)
self.assertListEqual(predicted_ids.tolist(), expected_labels)
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
torch.testing.assert_close(predicted_logits, expected_logits, rtol=1e-2, atol=1e-2)
def test_phoneme_recognition(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft").to(torch_device)
@@ -1936,7 +1936,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
)
self.assertEqual(labels[0, :, 0].sum(), 555)
self.assertEqual(labels[0, :, 1].sum(), 299)
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-2))
torch.testing.assert_close(outputs.logits[:, :4], expected_logits, rtol=1e-2, atol=1e-2)
def test_inference_speaker_verification(self):
model = Wav2Vec2ForXVector.from_pretrained("anton-l/wav2vec2-base-superb-sv").to(torch_device)