Fix qwen3_moe tests (#38865)

* try 1

* try 2

* try 3

* try 4

* try 5

* try 6

* try 7

* try 8

* try 9

* try 10

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-06-18 14:36:03 +02:00
committed by GitHub
parent 5a95ed5ca0
commit c77bcd889f
2 changed files with 47 additions and 68 deletions

View File

@@ -338,7 +338,7 @@ class AyaVisionIntegrationTest(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
del cls.model_checkpoint del cls.model
cleanup(torch_device, gc_collect=True) cleanup(torch_device, gc_collect=True)
def tearDown(self): def tearDown(self):

View File

@@ -13,18 +13,19 @@
# limitations under the License. # limitations under the License.
"""Testing suite for the PyTorch Qwen3MoE model.""" """Testing suite for the PyTorch Qwen3MoE model."""
import gc
import unittest import unittest
import pytest import pytest
from transformers import AutoTokenizer, Qwen3MoeConfig, is_torch_available, set_seed from transformers import AutoTokenizer, Qwen3MoeConfig, is_torch_available, set_seed
from transformers.testing_utils import ( from transformers.testing_utils import (
backend_empty_cache, cleanup,
require_bitsandbytes, require_bitsandbytes,
require_flash_attn, require_flash_attn,
require_torch, require_torch,
require_torch_gpu, require_torch_gpu,
require_torch_large_accelerator,
require_torch_multi_accelerator,
require_torch_sdpa, require_torch_sdpa,
slow, slow,
torch_device, torch_device,
@@ -143,34 +144,54 @@ class Qwen3MoeModelTest(CausalLMModelTest, unittest.TestCase):
self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item())
# Run on runners with larger accelerators (for example A10 instead of T4) with a lot of CPU RAM (e.g. g5-12xlarge)
@require_torch_multi_accelerator
@require_torch_large_accelerator
@require_torch @require_torch
class Qwen3MoeIntegrationTest(unittest.TestCase): class Qwen3MoeIntegrationTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = None
@classmethod
def tearDownClass(cls):
del cls.model
cleanup(torch_device, gc_collect=True)
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@classmethod
def get_model(cls):
if cls.model is None:
cls.model = Qwen3MoeForCausalLM.from_pretrained(
"Qwen/Qwen3-30B-A3B-Base", device_map="auto", load_in_4bit=True
)
return cls.model
@slow @slow
def test_model_15b_a2b_logits(self): def test_model_15b_a2b_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-15B-A2B-Base", device_map="auto") model = self.get_model()
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad(): with torch.no_grad():
out = model(input_ids).logits.float().cpu() out = model(input_ids).logits.float().cpu()
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[-1.1184, 1.1356, 9.2112, 8.0254, 5.1663, 7.9287, 8.9245, 10.0671]])
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([7.5938, 2.6094, 4.0312, 4.0938, 2.5156, 2.7812, 2.9688, 1.5547, 1.3984, 2.2344, 3.0156, 3.1562, 1.1953, 3.2500, 1.0938, 8.4375, 9.5625, 9.0625, 7.5625, 7.5625, 7.9062, 7.2188, 7.0312, 6.9375, 8.0625, 1.7266, 0.9141, 3.7969, 5.3438, 3.9844]) # fmt: skip
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
del model # Expected mean on dim = -1
backend_empty_cache(torch_device) EXPECTED_MEAN = torch.tensor([[0.3244, 0.4406, 9.0972, 7.3597, 4.9985, 8.0314, 8.2148, 9.2134]])
gc.collect() torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([6.8984, 4.8633, 4.7734, 4.5898, 2.5664, 2.9902, 4.8828, 5.9414, 4.6250, 3.0840, 5.1602, 6.0117, 4.9453, 5.3008, 3.3145, 11.3906, 12.8359, 12.4844, 11.2891, 11.0547, 11.0391, 10.3359, 10.3438, 10.2578, 10.7969, 5.9688, 3.7676, 5.5938, 5.3633, 5.8203]) # fmt: skip
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
@slow @slow
def test_model_15b_a2b_generation(self): def test_model_15b_a2b_generation(self):
EXPECTED_TEXT_COMPLETION = ( EXPECTED_TEXT_COMPLETION = "To be or not to be: the role of the cell cycle in the regulation of apoptosis.\nThe cell cycle is a highly"
"""To be or not to be, that is the question. Whether 'tis nobler in the mind to suffer the sl"""
)
prompt = "To be or not to" prompt = "To be or not to"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-15B-A2B-Base", use_fast=False) tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B-Base", use_fast=False)
model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-15B-A2B-Base", device_map="auto") model = self.get_model()
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
# greedy generation outputs # greedy generation outputs
@@ -178,10 +199,6 @@ class Qwen3MoeIntegrationTest(unittest.TestCase):
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text) self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
del model
backend_empty_cache(torch_device)
gc.collect()
@require_bitsandbytes @require_bitsandbytes
@slow @slow
@require_flash_attn @require_flash_attn
@@ -191,7 +208,7 @@ class Qwen3MoeIntegrationTest(unittest.TestCase):
# An input with 4097 tokens that is above the size of the sliding window # An input with 4097 tokens that is above the size of the sliding window
input_ids = [1] + [306, 338] * 2048 input_ids = [1] + [306, 338] * 2048
model = Qwen3MoeForCausalLM.from_pretrained( model = Qwen3MoeForCausalLM.from_pretrained(
"Qwen/Qwen3-15B-A2B-Base", "Qwen/Qwen3-30B-A3B-Base",
device_map="auto", device_map="auto",
load_in_4bit=True, load_in_4bit=True,
attn_implementation="flash_attention_2", attn_implementation="flash_attention_2",
@@ -200,50 +217,20 @@ class Qwen3MoeIntegrationTest(unittest.TestCase):
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
# Assisted generation
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
del assistant_model
del model
backend_empty_cache(torch_device)
gc.collect()
@slow @slow
@require_torch_sdpa @require_torch_sdpa
def test_model_15b_a2b_long_prompt_sdpa(self): def test_model_15b_a2b_long_prompt_sdpa(self):
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
# An input with 4097 tokens that is above the size of the sliding window # An input with 4097 tokens that is above the size of the sliding window
input_ids = [1] + [306, 338] * 2048 input_ids = [1] + [306, 338] * 2048
model = Qwen3MoeForCausalLM.from_pretrained( model = self.get_model()
"Qwen/Qwen3-15B-A2B-Base",
device_map="auto",
attn_implementation="sdpa",
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
# Assisted generation EXPECTED_TEXT_COMPLETION = "To be or not to be: the role of the cell cycle in the regulation of apoptosis.\nThe cell cycle is a highly"
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
generated_ids = assistant_model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
del assistant_model
backend_empty_cache(torch_device)
gc.collect()
EXPECTED_TEXT_COMPLETION = (
"""To be or not to be, that is the question. Whether 'tis nobler in the mind to suffer the sl"""
)
prompt = "To be or not to" prompt = "To be or not to"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-15B-A2B-Base", use_fast=False) tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B-Base", use_fast=False)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
@@ -255,16 +242,12 @@ class Qwen3MoeIntegrationTest(unittest.TestCase):
@slow @slow
def test_speculative_generation(self): def test_speculative_generation(self):
EXPECTED_TEXT_COMPLETION = ( EXPECTED_TEXT_COMPLETION = (
"To be or not to be, that is the question: whether 'tis nobler in the mind to suffer the sl" "To be or not to be: the role of the liver in the pathogenesis of obesity and type 2 diabetes.\nThe"
) )
prompt = "To be or not to" prompt = "To be or not to"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-15B-A2B-Base", use_fast=False) tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B-Base", use_fast=False)
model = Qwen3MoeForCausalLM.from_pretrained( model = self.get_model()
"Qwen/Qwen3-15B-A2B-Base", device_map="auto", torch_dtype=torch.float16 assistant_model = model
)
assistant_model = Qwen3MoeForCausalLM.from_pretrained(
"Qwen/Qwen3-15B-A2B-Base", device_map="auto", torch_dtype=torch.float16
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
# greedy generation outputs # greedy generation outputs
@@ -274,7 +257,3 @@ class Qwen3MoeIntegrationTest(unittest.TestCase):
) )
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text) self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
del model
backend_empty_cache(torch_device)
gc.collect()