@@ -375,7 +375,7 @@ class OlmoeIntegrationTest(unittest.TestCase):
|
||||
def test_model_7b_logits(self):
|
||||
input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]]
|
||||
model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924", device_map="auto")
|
||||
out = model(torch.tensor(input_ids)).logits
|
||||
out = model(torch.tensor(input_ids)).logits.float()
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[-1.3814, -3.4450, -2.2990, -1.9542, -2.4387, -2.7941, -2.9312, -2.8309]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
||||
|
||||
Reference in New Issue
Block a user