Tests: upcast logits to float() (#34042)

upcast
This commit is contained in:
Joao Gante
2024-10-11 11:51:49 +01:00
committed by GitHub
parent 4b9bfd32f0
commit e878eaa9fc
11 changed files with 35 additions and 19 deletions

View File

@@ -496,7 +496,7 @@ class PersimmonIntegrationTest(unittest.TestCase):
model = PersimmonForCausalLM.from_pretrained(
"adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
)
out = model(torch.tensor([input_ids], device=torch_device)).logits
out = model(torch.tensor([input_ids], device=torch_device)).logits.float()
EXPECTED_MEAN = torch.tensor(
[[-11.4726, -11.1495, -11.2694, -11.2223, -10.9452, -11.0663, -11.0031, -11.1028]]