use torch.testing.assertclose instead to get more details about error in cis (#35659)
* use torch.testing.assertclose instead to get more details about error in cis * fix * style * test_all * revert for I bert * fixes and updates * more image processing fixes * more image processors * fix mamba and co * style * less strick * ok I won't be strict * skip and be done * up
This commit is contained in:
@@ -1095,7 +1095,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
torch.testing.assert_close(output_slice, expected_output_slice, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_lsh_layer_forward_complex(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
@@ -1118,7 +1118,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
torch.testing.assert_close(output_slice, expected_output_slice, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_local_layer_forward(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
@@ -1136,7 +1136,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
torch.testing.assert_close(output_slice, expected_output_slice, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_local_layer_forward_complex(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
@@ -1158,7 +1158,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
torch.testing.assert_close(output_slice, expected_output_slice, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_lsh_model_forward(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
@@ -1175,7 +1175,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
torch.testing.assert_close(output_slice, expected_output_slice, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_local_model_forward(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
@@ -1191,7 +1191,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
torch.testing.assert_close(output_slice, expected_output_slice, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_lm_model_forward(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
@@ -1210,7 +1210,7 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
torch.testing.assert_close(output_slice, expected_output_slice, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_local_lm_model_grad(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
@@ -1224,7 +1224,9 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
input_ids, _ = self._get_input_ids_and_mask()
|
||||
loss = model(input_ids=input_ids, labels=input_ids)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(loss, torch.tensor(5.8019, dtype=torch.float, device=torch_device), atol=1e-3))
|
||||
torch.testing.assert_close(
|
||||
loss, torch.tensor(5.8019, dtype=torch.float, device=torch_device), rtol=1e-3, atol=1e-3
|
||||
)
|
||||
loss.backward()
|
||||
|
||||
# check last grads to cover all proable errors
|
||||
@@ -1246,9 +1248,9 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3))
|
||||
torch.testing.assert_close(grad_slice_word, expected_grad_slice_word, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_lsh_lm_model_grad(self):
|
||||
config = self._get_basic_config_and_input()
|
||||
@@ -1264,7 +1266,9 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
input_ids, _ = self._get_input_ids_and_mask()
|
||||
loss = model(input_ids=input_ids, labels=input_ids)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(loss, torch.tensor(5.7854, dtype=torch.float, device=torch_device), atol=1e-3))
|
||||
torch.testing.assert_close(
|
||||
loss, torch.tensor(5.7854, dtype=torch.float, device=torch_device), rtol=1e-3, atol=1e-3
|
||||
)
|
||||
loss.backward()
|
||||
# check last grads to cover all proable errors
|
||||
grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
|
||||
@@ -1285,9 +1289,9 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertTrue(torch.allclose(grad_slice_word, expected_grad_slice_word, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, atol=1e-3))
|
||||
torch.testing.assert_close(grad_slice_word, expected_grad_slice_word, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(grad_slice_position_factor_1, expected_grad_slice_pos_fac_1, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(grad_slice_position_factor_2, expected_grad_slice_pos_fac_2, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@slow
|
||||
def test_pretrained_generate_crime_and_punish(self):
|
||||
|
||||
Reference in New Issue
Block a user