@@ -580,7 +580,7 @@ class Qwen2MoeIntegrationTest(unittest.TestCase):
|
||||
model = Qwen2MoeForCausalLM.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B", device_map="auto")
|
||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||
with torch.no_grad():
|
||||
out = model(input_ids).logits.cpu()
|
||||
out = model(input_ids).logits.float().cpu()
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[-4.2125, -3.6416, -4.9136, -4.3005, -4.9938, -3.4393, -3.5195, -4.1621]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
||||
|
||||
Reference in New Issue
Block a user