Make HF implementation match original OLMo 2 models for lower precisions (#38131)

* Make HF implementation match OLMo models for lower precisions

* Add test of 1B logits in bfloat16

* Run make fixup
This commit is contained in:
Shane A
2025-05-19 06:35:23 -07:00
committed by GitHub
parent 9644acb7cb
commit aef12349b6
5 changed files with 135 additions and 69 deletions

View File

@@ -232,6 +232,18 @@ class Olmo2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
@require_torch
class Olmo2IntegrationTest(unittest.TestCase):
@slow
def test_model_1b_logits_bfloat16(self):
input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]]
model = Olmo2ForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B").to(torch.bfloat16)
out = model(torch.tensor(input_ids)).logits.float()
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[-5.7094, -6.5548, -3.2527, -2.7847, -5.5092, -4.5223, -4.8427, -4.6867]])
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([2.4531, -5.7188, -5.1562, -4.8750, -6.7812, -4.0625, -4.4375, -4.5938, -7.5938, -5.0938, -3.9375, -3.6875, -5.0938, -3.1875, -5.6875, 0.2266, 1.2578, 1.1016, 0.8945, 0.4785, 0.2256, -0.3613, -0.4258, 0.1377, -0.1104, -7.1875, -5.2188, -6.8125, -0.9062, -2.9062]) # fmt: skip
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
@slow
def test_model_7b_logits(self):
input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]]