use torch.testing.assertclose instead to get more details about error in cis (#35659)

* use torch.testing.assertclose instead to get more details about error in cis

* fix

* style

* test_all

* revert for I bert

* fixes and updates

* more image processing fixes

* more image processors

* fix mamba and co

* style

* less strick

* ok I won't be strict

* skip and be done

* up
This commit is contained in:
Arthur
2025-01-24 16:55:28 +01:00
committed by GitHub
parent 72d1a4cd53
commit b912f5ee43
255 changed files with 1048 additions and 969 deletions

View File

@@ -674,12 +674,12 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
device=torch_device,
)
self.assertTrue(
torch.allclose(prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, atol=1e-4)
torch.testing.assert_close(
prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, rtol=1e-4, atol=1e-4
)
expected_seq_relationship_logits = torch.tensor([[46.9465, 47.9517]], device=torch_device)
self.assertTrue(torch.allclose(seq_relationship_logits, expected_seq_relationship_logits, atol=1e-4))
torch.testing.assert_close(seq_relationship_logits, expected_seq_relationship_logits, rtol=1e-4, atol=1e-4)
def test_inference_full_pretraining(self):
model = BigBirdForPreTraining.from_pretrained("google/bigbird-roberta-base", attention_type="original_full")
@@ -703,12 +703,12 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
],
device=torch_device,
)
self.assertTrue(
torch.allclose(prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, atol=1e-4)
torch.testing.assert_close(
prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, rtol=1e-4, atol=1e-4
)
expected_seq_relationship_logits = torch.tensor([[41.4503, 41.2406]], device=torch_device)
self.assertTrue(torch.allclose(seq_relationship_logits, expected_seq_relationship_logits, atol=1e-4))
torch.testing.assert_close(seq_relationship_logits, expected_seq_relationship_logits, rtol=1e-4, atol=1e-4)
def test_block_sparse_attention_probs(self):
"""
@@ -773,7 +773,7 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
cl = torch.einsum("bhqk,bhkd->bhqd", attention_probs, value_layer)
cl = cl.view(context_layer.size())
self.assertTrue(torch.allclose(context_layer, cl, atol=0.001))
torch.testing.assert_close(context_layer, cl, rtol=0.001, atol=0.001)
def test_block_sparse_context_layer(self):
model = BigBirdModel.from_pretrained(
@@ -822,7 +822,7 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
context_layer = context_layer[0]
self.assertEqual(context_layer.shape, torch.Size((1, 128, 768)))
self.assertTrue(torch.allclose(context_layer[0, 64:78, 300:310], targeted_cl, atol=0.0001))
torch.testing.assert_close(context_layer[0, 64:78, 300:310], targeted_cl, rtol=0.0001, atol=0.0001)
def test_tokenizer_inference(self):
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
@@ -871,7 +871,7 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
device=torch_device,
)
self.assertTrue(torch.allclose(prediction[0, 52:64, 320:324], expected_prediction, atol=1e-4))
torch.testing.assert_close(prediction[0, 52:64, 320:324], expected_prediction, rtol=1e-4, atol=1e-4)
def test_inference_question_answering(self):
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-base-trivia-itc")
@@ -923,8 +923,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
)
# fmt: on
self.assertTrue(torch.allclose(start_logits[:, 64:96], target_start_logits, atol=1e-4))
self.assertTrue(torch.allclose(end_logits[:, 64:96], target_end_logits, atol=1e-4))
torch.testing.assert_close(start_logits[:, 64:96], target_start_logits, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(end_logits[:, 64:96], target_end_logits, rtol=1e-4, atol=1e-4)
input_ids = inputs["input_ids"].tolist()
answer = [
@@ -966,4 +966,4 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
# fmt: on
self.assertEqual(output.shape, torch.Size((1, 241, 768)))
self.assertTrue(torch.allclose(output[0, 64:78, 300:310], target, atol=0.0001))
torch.testing.assert_close(output[0, 64:78, 300:310], target, rtol=0.0001, atol=0.0001)