Tests: upcast logits to float() (#34042)

upcast
This commit is contained in:
Joao Gante
2024-10-11 11:51:49 +01:00
committed by GitHub
parent 4b9bfd32f0
commit e878eaa9fc
11 changed files with 35 additions and 19 deletions

View File

@@ -481,7 +481,7 @@ class JetMoeIntegrationTest(unittest.TestCase):
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
out = model(input_ids).logits.cpu()
out = model(input_ids).logits.float().cpu()
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[0.2507, -2.7073, -1.3445, -1.9363, -1.7216, -1.7370, -1.9054, -1.9792]])
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)