use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -674,12 +674,12 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, atol=1e-4)
|
||||
torch.testing.assert_close(
|
||||
prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, rtol=1e-4, atol=1e-4
|
||||
)
|
||||
|
||||
expected_seq_relationship_logits = torch.tensor([[46.9465, 47.9517]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(seq_relationship_logits, expected_seq_relationship_logits, atol=1e-4))
|
||||
torch.testing.assert_close(seq_relationship_logits, expected_seq_relationship_logits, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_inference_full_pretraining(self):
|
||||
model = BigBirdForPreTraining.from_pretrained("google/bigbird-roberta-base", attention_type="original_full")
|
||||
@@ -703,12 +703,12 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, atol=1e-4)
|
||||
torch.testing.assert_close(
|
||||
prediction_logits[0, 128:132, 128:132], expected_prediction_logits_slice, rtol=1e-4, atol=1e-4
|
||||
)
|
||||
|
||||
expected_seq_relationship_logits = torch.tensor([[41.4503, 41.2406]], device=torch_device)
|
||||
self.assertTrue(torch.allclose(seq_relationship_logits, expected_seq_relationship_logits, atol=1e-4))
|
||||
torch.testing.assert_close(seq_relationship_logits, expected_seq_relationship_logits, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_block_sparse_attention_probs(self):
|
||||
"""
|
||||
@@ -773,7 +773,7 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
|
||||
cl = torch.einsum("bhqk,bhkd->bhqd", attention_probs, value_layer)
|
||||
cl = cl.view(context_layer.size())
|
||||
|
||||
self.assertTrue(torch.allclose(context_layer, cl, atol=0.001))
|
||||
torch.testing.assert_close(context_layer, cl, rtol=0.001, atol=0.001)
|
||||
|
||||
def test_block_sparse_context_layer(self):
|
||||
model = BigBirdModel.from_pretrained(
|
||||
@@ -822,7 +822,7 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
|
||||
context_layer = context_layer[0]
|
||||
|
||||
self.assertEqual(context_layer.shape, torch.Size((1, 128, 768)))
|
||||
self.assertTrue(torch.allclose(context_layer[0, 64:78, 300:310], targeted_cl, atol=0.0001))
|
||||
torch.testing.assert_close(context_layer[0, 64:78, 300:310], targeted_cl, rtol=0.0001, atol=0.0001)
|
||||
|
||||
def test_tokenizer_inference(self):
|
||||
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
|
||||
@@ -871,7 +871,7 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(prediction[0, 52:64, 320:324], expected_prediction, atol=1e-4))
|
||||
torch.testing.assert_close(prediction[0, 52:64, 320:324], expected_prediction, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_inference_question_answering(self):
|
||||
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-base-trivia-itc")
|
||||
@@ -923,8 +923,8 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(start_logits[:, 64:96], target_start_logits, atol=1e-4))
|
||||
self.assertTrue(torch.allclose(end_logits[:, 64:96], target_end_logits, atol=1e-4))
|
||||
torch.testing.assert_close(start_logits[:, 64:96], target_start_logits, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(end_logits[:, 64:96], target_end_logits, rtol=1e-4, atol=1e-4)
|
||||
|
||||
input_ids = inputs["input_ids"].tolist()
|
||||
answer = [
|
||||
@@ -966,4 +966,4 @@ class BigBirdModelIntegrationTest(unittest.TestCase):
|
||||
# fmt: on
|
||||
|
||||
self.assertEqual(output.shape, torch.Size((1, 241, 768)))
|
||||
self.assertTrue(torch.allclose(output[0, 64:78, 300:310], target, atol=0.0001))
|
||||
torch.testing.assert_close(output[0, 64:78, 300:310], target, rtol=0.0001, atol=0.0001)
|
||||
|
||||
Reference in New Issue
Block a user