From 24e311f42b54f5f5fab6efcaa0c82eebd5608ba3 Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Tue, 1 Apr 2025 20:52:55 +0800 Subject: [PATCH] fix XPU UT error case brough by RNG difference btw XPU and CUDA (#37121) * fix XPU UT error case brough by RNG difference btw XPU and CUDA Signed-off-by: YAO Matrix * enable tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_model_7b_logits and tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_model_7b_logits_bf16 on xpu Signed-off-by: YAO Matrix * Revert "enable tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_model_7b_logits and tests/models/llama/test_modeling_llama.py::LlamaIntegrationTest::test_model_7b_logits_bf16 on xpu" This reverts commit 3ef83a4f0204642daa45fda56e8aca1afed24b4f. --------- Signed-off-by: YAO Matrix --- tests/generation/test_logits_process.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 7ba1502d42..9d7acdfcb5 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -976,7 +976,8 @@ class LogitsProcessorTest(unittest.TestCase): input_ids[:, -1] = 10 scores_wo_bias = scores[:, -1].clone() out = watermark(input_ids=input_ids, scores=scores) - self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all()) + greenlist_id = 3 if torch_device == "xpu" else 1 + self.assertTrue((out[:, greenlist_id] == scores_wo_bias + watermark.bias).all()) @parameterized.expand([(5, 3, 10000), (10, 5, 1000)]) def test_synthidtext_watermarking_processor_bias_uniformity(self, ngram_len, num_layers, vocab_size):