Fix qwen3 tests (#38862)
* fix
* update
* update
* update
* update
* update
* update
* format

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -13,7 +13,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Testing suite for the PyTorch Qwen3 model."""
|
"""Testing suite for the PyTorch Qwen3 model."""
|
||||||
|
|
||||||
import gc
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -23,7 +22,7 @@ from transformers import AutoTokenizer, Qwen3Config, is_torch_available, set_see
|
|||||||
from transformers.generation.configuration_utils import GenerationConfig
|
from transformers.generation.configuration_utils import GenerationConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
Expectations,
|
Expectations,
|
||||||
backend_empty_cache,
|
cleanup,
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -109,6 +108,12 @@ class Qwen3ModelTest(CausalLMModelTest, unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Qwen3IntegrationTest(unittest.TestCase):
|
class Qwen3IntegrationTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
cleanup(torch_device, gc_collect=True)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
cleanup(torch_device, gc_collect=True)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_600m_logits(self):
|
def test_model_600m_logits(self):
|
||||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||||
@@ -117,15 +122,12 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
out = model(input_ids).logits.float().cpu()
|
out = model(input_ids).logits.float().cpu()
|
||||||
# Expected mean on dim = -1
|
# Expected mean on dim = -1
|
||||||
EXPECTED_MEAN = torch.tensor([[-1.4577, 1.3261, 3.8498, 3.4229, 2.9009, 1.8813, 2.1530, 2.1431]])
|
EXPECTED_MEAN = torch.tensor([[-1.3789, 1.3029, 3.8262, 3.4637, 2.8796, 1.8357, 2.1290, 2.1814]])
|
||||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
|
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-4, atol=1e-4)
|
||||||
# slicing logits[0, 0, 0:30]
|
# slicing logits[0, 0, 0:30]
|
||||||
EXPECTED_SLICE = torch.tensor([5.9062, 6.0938, 5.5625, 3.8594, 2.6094, 1.9531, 4.3125, 4.9375, 3.8906, 3.1094, 3.6719, 5.1562, 6.9062, 5.7500, 5.4062, 7.0625, 8.7500, 8.7500, 8.1250, 7.9375, 8.0625, 7.5312, 7.3750, 7.2188, 7.2500, 5.8750, 2.8750, 4.3438, 2.3438, 2.2500]) # fmt: skip
|
EXPECTED_SLICE = torch.tensor([4.6905, 4.9243, 4.7101, 3.2052, 2.2683, 1.6576, 3.6529, 3.9800, 3.2605, 2.6475, 3.0468, 4.2296, 5.7443, 4.8940, 4.4883, 6.0323, 7.4057, 7.3710, 6.8373, 6.6323, 6.7114, 6.3069, 6.1751, 6.0416, 6.0793, 4.6975, 2.3286, 3.6387, 2.0757, 1.9813]) # fmt: skip
|
||||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
|
|
||||||
|
|
||||||
del model
|
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
|
||||||
backend_empty_cache(torch_device)
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_600m_generation(self):
|
def test_model_600m_generation(self):
|
||||||
@@ -140,10 +142,6 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
|||||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||||
|
|
||||||
del model
|
|
||||||
backend_empty_cache(torch_device)
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
@slow
|
@slow
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@@ -169,20 +167,16 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
|||||||
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||||
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
||||||
|
|
||||||
del assistant_model
|
|
||||||
del model
|
|
||||||
backend_empty_cache(torch_device)
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_sdpa
|
@require_torch_sdpa
|
||||||
def test_model_600m_long_prompt_sdpa(self):
|
def test_model_600m_long_prompt_sdpa(self):
|
||||||
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
|
EXPECTED_OUTPUT_TOKEN_IDS = [198, 198]
|
||||||
# An input with 4097 tokens that is above the size of the sliding window
|
# An input with 4097 tokens that is above the size of the sliding window
|
||||||
input_ids = [1] + [306, 338] * 2048
|
input_ids = [1] + [306, 338] * 2048
|
||||||
model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", attn_implementation="sdpa")
|
model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", attn_implementation="sdpa")
|
||||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||||
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||||
|
|
||||||
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
||||||
|
|
||||||
# Assisted generation
|
# Assisted generation
|
||||||
@@ -190,12 +184,12 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
|||||||
assistant_model.generation_config.num_assistant_tokens = 2
|
assistant_model.generation_config.num_assistant_tokens = 2
|
||||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
|
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
|
||||||
generated_ids = assistant_model.generate(input_ids, max_new_tokens=4, temperature=0)
|
generated_ids = assistant_model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||||
|
|
||||||
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
|
||||||
|
|
||||||
del assistant_model
|
del assistant_model
|
||||||
|
|
||||||
backend_empty_cache(torch_device)
|
cleanup(torch_device, gc_collect=True)
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"
|
EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"
|
||||||
prompt = "My favourite condiment is "
|
prompt = "My favourite condiment is "
|
||||||
@@ -206,13 +200,19 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
|||||||
# greedy generation outputs
|
# greedy generation outputs
|
||||||
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
|
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
|
||||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_speculative_generation(self):
|
def test_speculative_generation(self):
|
||||||
EXPECTED_TEXT_COMPLETION = (
|
EXPECTED_TEXT_COMPLETIONS = Expectations(
|
||||||
"My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it"
|
{
|
||||||
)
|
("cuda", 7): "My favourite condiment is 100% natural. It's a little spicy and a little sweet, but it's the",
|
||||||
|
("cuda", 8): "My favourite condiment is 100% peanut butter. I love it so much that I can't help but use it",
|
||||||
|
}
|
||||||
|
) # fmt: skip
|
||||||
|
EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
|
||||||
|
|
||||||
prompt = "My favourite condiment is "
|
prompt = "My favourite condiment is "
|
||||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base", use_fast=False)
|
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base", use_fast=False)
|
||||||
model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", torch_dtype=torch.float16)
|
model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", device_map="auto", torch_dtype=torch.float16)
|
||||||
@@ -227,11 +227,8 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
|||||||
input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=assistant_model
|
input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=assistant_model
|
||||||
)
|
)
|
||||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
|
||||||
|
|
||||||
del model
|
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||||
backend_empty_cache(torch_device)
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_export_static_cache(self):
|
def test_export_static_cache(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user