Fix qwen3 tests (#38862)

* fix

* update

* update

* update

* update

* update

* update

* format

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-06-17 15:21:36 +02:00
committed by GitHub
parent 41e0c921cb
commit 2507169bf6

View File

@@ -13,7 +13,6 @@
# limitations under the License.
"""Testing suite for the PyTorch Qwen3 model."""
import gc
import unittest
import pytest
@@ -23,7 +22,7 @@ from transformers import AutoTokenizer, Qwen3Config, is_torch_available, set_see
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
Expectations,
backend_empty_cache,
cleanup,
require_bitsandbytes,
require_flash_attn,
require_torch,
@@ -109,6 +108,12 @@ class Qwen3ModelTest(CausalLMModelTest, unittest.TestCase):
@require_torch
class Qwen3IntegrationTest(unittest.TestCase):
def setUp(self):
cleanup(torch_device, gc_collect=True)
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@slow
def test_model_600m_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
@@ -117,15 +122,12 @@ class Qwen3IntegrationTest(unittest.TestCase):
with torch.no_grad():
out = model(input_ids).logits.float().cpu()
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[-1.4577, 1.3261, 3.8498, 3.4229, 2.9009, 1.8813, 2.1530, 2.1431]])
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
EXPECTED_MEAN = torch.tensor([[-1.3789, 1.3029, 3.8262, 3.4637, 2.8796, 1.8357, 2.1290, 2.1814]])
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-4, atol=1e-4)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([5.9062, 6.0938, 5.5625, 3.8594, 2.6094, 1.9531, 4.3125, 4.9375, 3.8906, 3.1094, 3.6719, 5.1562, 6.9062, 5.7500, 5.4062, 7.0625, 8.7500, 8.7500, 8.1250, 7.9375, 8.0625, 7.5312, 7.3750, 7.2188, 7.2500, 5.8750, 2.8750, 4.3438, 2.3438, 2.2500]) # fmt: skip
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
EXPECTED_SLICE = torch.tensor([4.6905, 4.9243, 4.7101, 3.2052, 2.2683, 1.6576, 3.6529, 3.9800, 3.2605, 2.6475, 3.0468, 4.2296, 5.7443, 4.8940, 4.4883, 6.0323, 7.4057, 7.3710, 6.8373, 6.6323, 6.7114, 6.3069, 6.1751, 6.0416, 6.0793, 4.6975, 2.3286, 3.6387, 2.0757, 1.9813]) # fmt: skip
del model
backend_empty_cache(torch_device)
gc.collect()
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
@slow
def test_model_600m_generation(self):
@@ -140,10 +142,6 @@ class Qwen3IntegrationTest(unittest.TestCase):
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
del model
backend_empty_cache(torch_device)
gc.collect()
@require_bitsandbytes
@slow
@require_flash_attn
@@ -169,20 +167,16 @@ class Qwen3IntegrationTest(unittest.TestCase):
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
del assistant_model
del model
backend_empty_cache(torch_device)
gc.collect()
@slow
@require_torch_sdpa
def test_model_600m_long_prompt_sdpa(self):
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
EXPECTED_OUTPUT_TOKEN_IDS = [198, 198]
# An input with 4097 tokens that is above the size of the sliding window
input_ids = [1] + [306, 338] * 2048
model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", attn_implementation="sdpa")
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
# Assisted generation
@@ -190,12 +184,12 @@ class Qwen3IntegrationTest(unittest.TestCase):
assistant_model.generation_config.num_assistant_tokens = 2
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
generated_ids = assistant_model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
del assistant_model
backend_empty_cache(torch_device)
gc.collect()
cleanup(torch_device, gc_collect=True)
EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"
prompt = "My favourite condiment is "
@@ -206,13 +200,19 @@ class Qwen3IntegrationTest(unittest.TestCase):
# greedy generation outputs
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
def test_speculative_generation(self):
EXPECTED_TEXT_COMPLETION = (
"My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it"
)
EXPECTED_TEXT_COMPLETIONS = Expectations(
{
("cuda", 7): "My favourite condiment is 100% natural. It's a little spicy and a little sweet, but it's the",
("cuda", 8): "My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it",
}
) # fmt: skip
EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
prompt = "My favourite condiment is "
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base", use_fast=False)
model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", torch_dtype=torch.float16)
@@ -227,11 +227,8 @@ class Qwen3IntegrationTest(unittest.TestCase):
input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=assistant_model
)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
del model
backend_empty_cache(torch_device)
gc.collect()
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
def test_export_static_cache(self):