use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -418,12 +418,12 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
with torch.no_grad():
|
||||
logits_single_first = model(input_ids=input_ids[:1, :-chunk_length], labels=labels[:1]).logits
|
||||
|
||||
self.assertTrue(torch.allclose(logits_batched[0, -3:], logits_single_first[0, -3:], atol=tolerance))
|
||||
torch.testing.assert_close(logits_batched[0, -3:], logits_single_first[0, -3:], rtol=tolerance, atol=tolerance)
|
||||
|
||||
with torch.no_grad():
|
||||
logits_single_second = model(input_ids=input_ids[1:], labels=labels[1:, :-4]).logits
|
||||
|
||||
self.assertTrue(torch.allclose(logits_batched[1, :3], logits_single_second[0, :3], atol=tolerance))
|
||||
torch.testing.assert_close(logits_batched[1, :3], logits_single_second[0, :3], rtol=tolerance, atol=tolerance)
|
||||
|
||||
def test_auto_padding(self):
|
||||
ids = [[7, 6, 9] * 65]
|
||||
@@ -445,7 +445,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
"logits"
|
||||
]
|
||||
|
||||
self.assertTrue(torch.allclose(output1, output2, atol=1e-5))
|
||||
torch.testing.assert_close(output1, output2, rtol=1e-5, atol=1e-5)
|
||||
|
||||
def test_for_change_to_full_attn(self):
|
||||
self.model_tester.seq_length = 9
|
||||
@@ -462,7 +462,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
model.load_state_dict(state_dict)
|
||||
outputs2 = model(**input_dict)["logits"]
|
||||
|
||||
self.assertTrue(torch.allclose(outputs1, outputs2, atol=1e-5))
|
||||
torch.testing.assert_close(outputs1, outputs2, rtol=1e-5, atol=1e-5)
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
|
||||
@@ -523,8 +523,8 @@ class BigBirdPegasusModelIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
# fmt: on
|
||||
self.assertTrue(
|
||||
torch.allclose(prediction_logits[0, 4:8, 128:156], expected_prediction_logits_slice, atol=1e-4)
|
||||
torch.testing.assert_close(
|
||||
prediction_logits[0, 4:8, 128:156], expected_prediction_logits_slice, rtol=1e-4, atol=1e-4
|
||||
)
|
||||
|
||||
def test_inference_full_attn(self):
|
||||
@@ -544,8 +544,8 @@ class BigBirdPegasusModelIntegrationTests(unittest.TestCase):
|
||||
device=torch_device,
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(
|
||||
torch.allclose(prediction_logits[0, 4:8, 128:156], expected_prediction_logits_slice, atol=1e-4)
|
||||
torch.testing.assert_close(
|
||||
prediction_logits[0, 4:8, 128:156], expected_prediction_logits_slice, rtol=1e-4, atol=1e-4
|
||||
)
|
||||
|
||||
def test_seq_to_seq_generation(self):
|
||||
|
||||
Reference in New Issue
Block a user