Make HF implementation match original OLMo 2 models for lower precisions (#38131)
* Make HF implementation match OLMo models for lower precisions * Add test of 1B logits in bfloat16 * Run make fixup
This commit is contained in:
@@ -232,6 +232,18 @@ class Olmo2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
@require_torch
|
||||
class Olmo2IntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_model_1b_logits_bfloat16(self):
|
||||
input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]]
|
||||
model = Olmo2ForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B").to(torch.bfloat16)
|
||||
out = model(torch.tensor(input_ids)).logits.float()
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[-5.7094, -6.5548, -3.2527, -2.7847, -5.5092, -4.5223, -4.8427, -4.6867]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
|
||||
# slicing logits[0, 0, 0:30]
|
||||
EXPECTED_SLICE = torch.tensor([2.4531, -5.7188, -5.1562, -4.8750, -6.7812, -4.0625, -4.4375, -4.5938, -7.5938, -5.0938, -3.9375, -3.6875, -5.0938, -3.1875, -5.6875, 0.2266, 1.2578, 1.1016, 0.8945, 0.4785, 0.2256, -0.3613, -0.4258, 0.1377, -0.1104, -7.1875, -5.2188, -6.8125, -0.9062, -2.9062]) # fmt: skip
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
|
||||
|
||||
@slow
|
||||
def test_model_7b_logits(self):
|
||||
input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]]
|
||||
|
||||
Reference in New Issue
Block a user