Tests: upcast logits to float() (#34042)

upcast
This commit is contained in:
Joao Gante
2024-10-11 11:51:49 +01:00
committed by GitHub
parent 4b9bfd32f0
commit e878eaa9fc
11 changed files with 35 additions and 19 deletions

View File

@@ -482,7 +482,7 @@ class StableLmModelIntegrationTest(unittest.TestCase):
model = StableLmForCausalLM.from_pretrained("stabilityai/stablelm-3b-4e1t").to(torch_device)
model.eval()
output = model(**input_ids).logits
output = model(**input_ids).logits.float()
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[2.7146, 2.4245, 1.5616, 1.4424, 2.6790]]).to(torch_device)
@@ -515,7 +515,7 @@ class StableLmModelIntegrationTest(unittest.TestCase):
model = StableLmForCausalLM.from_pretrained("stabilityai/tiny-random-stablelm-2").to(torch_device)
model.eval()
output = model(**input_ids).logits
output = model(**input_ids).logits.float()
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[-2.7196, -3.6099, -2.6877, -3.1973, -3.9344]]).to(torch_device)