Tests: upcast logits to float() (#34042)

upcast
This commit is contained in:
Joao Gante
2024-10-11 11:51:49 +01:00
committed by GitHub
parent 4b9bfd32f0
commit e878eaa9fc
11 changed files with 35 additions and 19 deletions

View File

@@ -538,7 +538,7 @@ class GraniteIntegrationTest(unittest.TestCase):
self.assertTrue(
torch.allclose(
EXPECTED_SLICE.to(torch_device),
out.logits[0, 0, :15],
out.logits[0, 0, :15].float(),
atol=1e-3,
rtol=1e-3,
)
@@ -558,4 +558,4 @@ class GraniteIntegrationTest(unittest.TestCase):
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[-2.0984, -3.1294, -2.8153, -2.3568, -2.7337, -2.2624, -2.6016, -2.4022]])
self.assertTrue(torch.allclose(EXPECTED_MEAN.to(torch_device), out.logits.mean(-1), atol=1e-2, rtol=1e-2))
self.assertTrue(torch.allclose(EXPECTED_MEAN.to(torch_device), out.logits.float().mean(-1), atol=1e-2, rtol=1e-2))