use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -1166,7 +1166,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
logits_2 = get_logits(model_2, input_features)
|
||||
|
||||
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
|
||||
torch.testing.assert_close(logits, logits_2, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# test that loading adapter weights with mismatched vocab sizes can be loaded
|
||||
def test_load_target_lang_with_mismatched_size(self):
|
||||
@@ -1203,7 +1203,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
logits_2 = get_logits(model_2, input_features)
|
||||
|
||||
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
|
||||
torch.testing.assert_close(logits, logits_2, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_load_attn_adapter(self):
|
||||
processor = Wav2Vec2Processor.from_pretrained(
|
||||
@@ -1250,7 +1250,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model.load_adapter("ita", use_safetensors=True)
|
||||
logits_2 = get_logits(model, input_features)
|
||||
|
||||
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
|
||||
torch.testing.assert_close(logits, logits_2, rtol=1e-3, atol=1e-3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
model.save_pretrained(tempdir)
|
||||
@@ -1271,7 +1271,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
logits_2 = get_logits(model, input_features)
|
||||
|
||||
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
|
||||
torch.testing.assert_close(logits, logits_2, rtol=1e-3, atol=1e-3)
|
||||
|
||||
model = Wav2Vec2ForCTC.from_pretrained("hf-internal-testing/tiny-random-wav2vec2-adapter")
|
||||
logits = get_logits(model, input_features)
|
||||
@@ -1282,7 +1282,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
logits_2 = get_logits(model, input_features)
|
||||
|
||||
self.assertTrue(torch.allclose(logits, logits_2, atol=1e-3))
|
||||
torch.testing.assert_close(logits, logits_2, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
@@ -1595,7 +1595,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
], device=torch_device)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
|
||||
torch.testing.assert_close(cosine_sim_masked, expected_cosine_sim_masked, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_inference_pretrained(self):
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
|
||||
@@ -1734,7 +1734,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
expected_logits = torch.tensor([6.1186, 11.8961, 10.2931, 6.0898], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
torch.testing.assert_close(predicted_logits, expected_logits, rtol=1e-2, atol=1e-2)
|
||||
|
||||
def test_inference_intent_classification(self):
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic").to(torch_device)
|
||||
@@ -1762,9 +1762,9 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(predicted_ids_object.tolist(), expected_labels_object)
|
||||
self.assertListEqual(predicted_ids_location.tolist(), expected_labels_location)
|
||||
|
||||
self.assertTrue(torch.allclose(predicted_logits_action, expected_logits_action, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(predicted_logits_object, expected_logits_object, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(predicted_logits_location, expected_logits_location, atol=1e-2))
|
||||
torch.testing.assert_close(predicted_logits_action, expected_logits_action, rtol=1e-2, atol=1e-2)
|
||||
torch.testing.assert_close(predicted_logits_object, expected_logits_object, rtol=1e-2, atol=1e-2)
|
||||
torch.testing.assert_close(predicted_logits_location, expected_logits_location, rtol=1e-2, atol=1e-2)
|
||||
|
||||
def test_inference_speaker_identification(self):
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid").to(torch_device)
|
||||
@@ -1785,7 +1785,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
expected_logits = torch.tensor([37.5627, 71.6362, 64.2419, 31.7778], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
torch.testing.assert_close(predicted_logits, expected_logits, rtol=1e-2, atol=1e-2)
|
||||
|
||||
def test_inference_emotion_recognition(self):
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er").to(torch_device)
|
||||
@@ -1804,7 +1804,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
expected_logits = torch.tensor([2.1722, 3.0779, 8.0287, 6.6797], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
torch.testing.assert_close(predicted_logits, expected_logits, rtol=1e-2, atol=1e-2)
|
||||
|
||||
def test_phoneme_recognition(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft").to(torch_device)
|
||||
@@ -1936,7 +1936,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(labels[0, :, 0].sum(), 555)
|
||||
self.assertEqual(labels[0, :, 1].sum(), 299)
|
||||
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-2))
|
||||
torch.testing.assert_close(outputs.logits[:, :4], expected_logits, rtol=1e-2, atol=1e-2)
|
||||
|
||||
def test_inference_speaker_verification(self):
|
||||
model = Wav2Vec2ForXVector.from_pretrained("anton-l/wav2vec2-base-superb-sv").to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user