@@ -481,7 +481,7 @@ class JetMoeIntegrationTest(unittest.TestCase):
|
||||
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
|
||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||
with torch.no_grad():
|
||||
out = model(input_ids).logits.cpu()
|
||||
out = model(input_ids).logits.float().cpu()
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[0.2507, -2.7073, -1.3445, -1.9363, -1.7216, -1.7370, -1.9054, -1.9792]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
||||
|
||||
Reference in New Issue
Block a user