From 20a04497a86dc79adb8a93e31325f25255d317e1 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 14 Aug 2024 16:22:06 +0200 Subject: [PATCH] Fix `JetMoeIntegrationTest` (#32332) JetMoeIntegrationTest Co-authored-by: ydshieh --- tests/models/jetmoe/test_modeling_jetmoe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/jetmoe/test_modeling_jetmoe.py b/tests/models/jetmoe/test_modeling_jetmoe.py index cdb82cb5a9..50fd7a27e1 100644 --- a/tests/models/jetmoe/test_modeling_jetmoe.py +++ b/tests/models/jetmoe/test_modeling_jetmoe.py @@ -478,7 +478,7 @@ class JetMoeIntegrationTest(unittest.TestCase): @slow def test_model_8b_logits(self): input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] - model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto") + model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b") input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) with torch.no_grad(): out = model(input_ids).logits.cpu() @@ -498,7 +498,7 @@ class JetMoeIntegrationTest(unittest.TestCase): EXPECTED_TEXT_COMPLETION = """My favourite condiment is ....\nI love ketchup. I love""" prompt = "My favourite condiment is " tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False) - model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto") + model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b") input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) # greedy generation outputs @@ -521,7 +521,7 @@ class JetMoeIntegrationTest(unittest.TestCase): "My favourite ", ] tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False) - model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto") + model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b") input_ids = tokenizer(prompt, return_tensors="pt", padding=True).to(model.model.embed_tokens.weight.device) print(input_ids)