Fix qwen3 tests (#38862)
* fix * update * update * update * update * update * update * format --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Qwen3 model."""
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
@@ -23,7 +22,7 @@ from transformers import AutoTokenizer, Qwen3Config, is_torch_available, set_see
|
||||
from transformers.generation.configuration_utils import GenerationConfig
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
backend_empty_cache,
|
||||
cleanup,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
@@ -109,6 +108,12 @@ class Qwen3ModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
class Qwen3IntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
@slow
|
||||
def test_model_600m_logits(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
@@ -117,15 +122,12 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
out = model(input_ids).logits.float().cpu()
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[-1.4577, 1.3261, 3.8498, 3.4229, 2.9009, 1.8813, 2.1530, 2.1431]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
|
||||
EXPECTED_MEAN = torch.tensor([[-1.3789, 1.3029, 3.8262, 3.4637, 2.8796, 1.8357, 2.1290, 2.1814]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-4, atol=1e-4)
|
||||
# slicing logits[0, 0, 0:30]
|
||||
EXPECTED_SLICE = torch.tensor([5.9062, 6.0938, 5.5625, 3.8594, 2.6094, 1.9531, 4.3125, 4.9375, 3.8906, 3.1094, 3.6719, 5.1562, 6.9062, 5.7500, 5.4062, 7.0625, 8.7500, 8.7500, 8.1250, 7.9375, 8.0625, 7.5312, 7.3750, 7.2188, 7.2500, 5.8750, 2.8750, 4.3438, 2.3438, 2.2500]) # fmt: skip
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
|
||||
EXPECTED_SLICE = torch.tensor([4.6905, 4.9243, 4.7101, 3.2052, 2.2683, 1.6576, 3.6529, 3.9800, 3.2605, 2.6475, 3.0468, 4.2296, 5.7443, 4.8940, 4.4883, 6.0323, 7.4057, 7.3710, 6.8373, 6.6323, 6.7114, 6.3069, 6.1751, 6.0416, 6.0793, 4.6975, 2.3286, 3.6387, 2.0757, 1.9813]) # fmt: skip
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
|
||||
|
||||
@slow
|
||||
def test_model_600m_generation(self):
|
||||
@@ -140,10 +142,6 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
@require_bitsandbytes
|
||||
@slow
|
||||
@require_flash_attn
|
||||
@@ -169,20 +167,16 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
||||
|
||||
del assistant_model
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
@require_torch_sdpa
|
||||
def test_model_600m_long_prompt_sdpa(self):
|
||||
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
|
||||
EXPECTED_OUTPUT_TOKEN_IDS = [198, 198]
|
||||
# An input with 4097 tokens that is above the size of the sliding window
|
||||
input_ids = [1] + [306, 338] * 2048
|
||||
model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", attn_implementation="sdpa")
|
||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||
|
||||
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
||||
|
||||
# Assisted generation
|
||||
@@ -190,12 +184,12 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
||||
assistant_model.generation_config.num_assistant_tokens = 2
|
||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
|
||||
generated_ids = assistant_model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||
|
||||
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
||||
|
||||
del assistant_model
|
||||
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"
|
||||
prompt = "My favourite condiment is "
|
||||
@@ -206,13 +200,19 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@slow
|
||||
def test_speculative_generation(self):
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it"
|
||||
)
|
||||
EXPECTED_TEXT_COMPLETIONS = Expectations(
|
||||
{
|
||||
("cuda", 7): "My favourite condiment is 100% natural. It's a little spicy and a little sweet, but it's the",
|
||||
("cuda", 8): "My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it",
|
||||
}
|
||||
) # fmt: skip
|
||||
EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
|
||||
|
||||
prompt = "My favourite condiment is "
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base", use_fast=False)
|
||||
model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", torch_dtype=torch.float16)
|
||||
@@ -227,11 +227,8 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
||||
input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=assistant_model
|
||||
)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@slow
|
||||
def test_export_static_cache(self):
|
||||
|
||||
Reference in New Issue
Block a user