Tests: upcast logits to float() (#34042)

upcast
This commit is contained in:
Joao Gante
2024-10-11 11:51:49 +01:00
committed by GitHub
parent 4b9bfd32f0
commit e878eaa9fc
11 changed files with 35 additions and 19 deletions

View File

@@ -524,7 +524,7 @@ class MistralIntegrationTest(unittest.TestCase):
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
out = model(input_ids).logits.cpu()
out = model(input_ids).logits.float().cpu()
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]])
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)