Use assertAlmostEqual in BloomEmbeddingTest.test_logits (#19200)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -771,8 +771,8 @@ class BloomEmbeddingTest(unittest.TestCase):
|
|||||||
|
|
||||||
output_gpu_1, output_gpu_2 = output.split(125440, dim=-1)
|
output_gpu_1, output_gpu_2 = output.split(125440, dim=-1)
|
||||||
if cuda_available:
|
if cuda_available:
|
||||||
self.assertEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1)
|
self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6)
|
||||||
self.assertEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2)
|
self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6)
|
||||||
else:
|
else:
|
||||||
self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6) # 1e-06 precision!!
|
self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6) # 1e-06 precision!!
|
||||||
self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6)
|
self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6)
|
||||||
|
|||||||
Reference in New Issue
Block a user