From 2507169bf658e39e6ffe89a04b32e3729b218b73 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 17 Jun 2025 15:21:36 +0200 Subject: [PATCH] Fix `qwen3` tests (#38862) * fix * update * update * update * update * update * update * format --------- Co-authored-by: ydshieh --- tests/models/qwen3/test_modeling_qwen3.py | 53 +++++++++++------------ 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py index 2664e1b692..3f3f5bae08 100644 --- a/tests/models/qwen3/test_modeling_qwen3.py +++ b/tests/models/qwen3/test_modeling_qwen3.py @@ -13,7 +13,6 @@ # limitations under the License. """Testing suite for the PyTorch Qwen3 model.""" -import gc import unittest import pytest @@ -23,7 +22,7 @@ from transformers import AutoTokenizer, Qwen3Config, is_torch_available, set_see from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( Expectations, - backend_empty_cache, + cleanup, require_bitsandbytes, require_flash_attn, require_torch, @@ -109,6 +108,12 @@ class Qwen3ModelTest(CausalLMModelTest, unittest.TestCase): @require_torch class Qwen3IntegrationTest(unittest.TestCase): + def setUp(self): + cleanup(torch_device, gc_collect=True) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + @slow def test_model_600m_logits(self): input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] @@ -117,15 +122,12 @@ class Qwen3IntegrationTest(unittest.TestCase): with torch.no_grad(): out = model(input_ids).logits.float().cpu() # Expected mean on dim = -1 - EXPECTED_MEAN = torch.tensor([[-1.4577, 1.3261, 3.8498, 3.4229, 2.9009, 1.8813, 2.1530, 2.1431]]) - torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2) + EXPECTED_MEAN = torch.tensor([[-1.3789, 1.3029, 3.8262, 3.4637, 2.8796, 1.8357, 2.1290, 2.1814]]) + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-4, atol=1e-4) # slicing logits[0, 0, 0:30] - EXPECTED_SLICE = torch.tensor([5.9062, 6.0938, 5.5625, 3.8594, 2.6094, 1.9531, 4.3125, 4.9375, 3.8906, 3.1094, 3.6719, 5.1562, 6.9062, 5.7500, 5.4062, 7.0625, 8.7500, 8.7500, 8.1250, 7.9375, 8.0625, 7.5312, 7.3750, 7.2188, 7.2500, 5.8750, 2.8750, 4.3438, 2.3438, 2.2500]) # fmt: skip - torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4) + EXPECTED_SLICE = torch.tensor([4.6905, 4.9243, 4.7101, 3.2052, 2.2683, 1.6576, 3.6529, 3.9800, 3.2605, 2.6475, 3.0468, 4.2296, 5.7443, 4.8940, 4.4883, 6.0323, 7.4057, 7.3710, 6.8373, 6.6323, 6.7114, 6.3069, 6.1751, 6.0416, 6.0793, 4.6975, 2.3286, 3.6387, 2.0757, 1.9813]) # fmt: skip - del model - backend_empty_cache(torch_device) - gc.collect() + torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4) @slow def test_model_600m_generation(self): @@ -140,10 +142,6 @@ class Qwen3IntegrationTest(unittest.TestCase): text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) - del model - backend_empty_cache(torch_device) - gc.collect() - @require_bitsandbytes @slow @require_flash_attn @@ -169,20 +167,16 @@ class Qwen3IntegrationTest(unittest.TestCase): generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) - del assistant_model - del model - backend_empty_cache(torch_device) - gc.collect() - @slow @require_torch_sdpa def test_model_600m_long_prompt_sdpa(self): - EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + EXPECTED_OUTPUT_TOKEN_IDS = [198, 198] # An input with 4097 tokens that is above the size of the sliding window input_ids = [1] + [306, 338] * 2048 model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", attn_implementation="sdpa") input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) # Assisted generation @@ -190,12 +184,12 @@ class Qwen3IntegrationTest(unittest.TestCase): assistant_model.generation_config.num_assistant_tokens = 2 assistant_model.generation_config.num_assistant_tokens_schedule = "constant" generated_ids = assistant_model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) del assistant_model - backend_empty_cache(torch_device) - gc.collect() + cleanup(torch_device, gc_collect=True) EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% plain, unflavoured, and unadulterated. It is" prompt = "My favourite condiment is " @@ -206,13 +200,19 @@ class Qwen3IntegrationTest(unittest.TestCase): # greedy generation outputs generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) @slow def test_speculative_generation(self): - EXPECTED_TEXT_COMPLETION = ( - "My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it" - ) + EXPECTED_TEXT_COMPLETIONS = Expectations( + { + ("cuda", 7): "My favourite condiment is 100% natural. It's a little spicy and a little sweet, but it's the", + ("cuda", 8): "My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it", + } + ) # fmt: skip + EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation() + prompt = "My favourite condiment is " tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base", use_fast=False) model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", torch_dtype=torch.float16) @@ -227,11 +227,8 @@ class Qwen3IntegrationTest(unittest.TestCase): input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=assistant_model ) text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, text) - del model - backend_empty_cache(torch_device) - gc.collect() + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) @slow def test_export_static_cache(self):