From 3a6ab46a0b85479d6fb0d6ce0bff2e48b4751ac4 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 1 Apr 2025 17:09:29 +0800 Subject: [PATCH] add gpt2 test on XPU (#37028) * add gpt2 test on XPU Signed-off-by: jiqing-feng * auto dtype has been fixed Signed-off-by: jiqing-feng * convert model to train mode Signed-off-by: jiqing-feng --------- Signed-off-by: jiqing-feng --- tests/quantization/bnb/test_4bit.py | 1 - tests/quantization/bnb/test_mixed_int8.py | 9 ++------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index bf137a6af5..2d40e90104 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -626,7 +626,6 @@ class Bnb4BitTestTraining(Base4bitTest): @apply_skip_if_not_implemented -@unittest.skipIf(torch_device == "xpu", reason="XPU has precision issue on gpt model, will test it once fixed") class Bnb4BitGPT2Test(Bnb4BitTest): model_name = "openai-community/gpt2-xl" EXPECTED_RELATIVE_DIFFERENCE = 3.3191854854152187 diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index bc7804de9b..26191baa4e 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -889,6 +889,7 @@ class MixedInt8TestTraining(BaseMixedInt8Test): # Step 1: freeze all parameters model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True) + model.train() if torch.cuda.is_available(): self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()}) @@ -914,14 +915,9 @@ class MixedInt8TestTraining(BaseMixedInt8Test): batch = self.tokenizer("Test batch ", return_tensors="pt").to(torch_device) # Step 4: Check if the gradient is not None - if torch_device in {"xpu", "cpu"}: - # XPU and CPU finetune do not support autocast for now. + with torch.autocast(torch_device): out = model.forward(**batch) out.logits.norm().backward() - else: - with torch.autocast(torch_device): - out = model.forward(**batch) - out.logits.norm().backward() for module in model.modules(): if isinstance(module, LoRALayer): @@ -932,7 +928,6 @@ class MixedInt8TestTraining(BaseMixedInt8Test): @apply_skip_if_not_implemented -@unittest.skipIf(torch_device == "xpu", reason="XPU has precision issue on gpt model, will test it once fixed") class MixedInt8GPT2Test(MixedInt8Test): model_name = "openai-community/gpt2-xl" EXPECTED_RELATIVE_DIFFERENCE = 1.8720077507258357