[qwen3] fix generation tests (#37142)
* do not skip tests * fix qwen3-moe as well * fixup * fixup
This commit is contained in:
committed by
GitHub
parent
e686fed635
commit
8805600406
@@ -352,7 +352,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
|
|
||||||
def test_Mistral_sequence_classification_model(self):
|
def test_Mistral_sequence_classification_model(self):
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
print(config)
|
|
||||||
config.num_labels = 3
|
config.num_labels = 3
|
||||||
input_ids = input_dict["input_ids"]
|
input_ids = input_dict["input_ids"]
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
|
|||||||
@@ -351,7 +351,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
|
|
||||||
def test_Mixtral_sequence_classification_model(self):
|
def test_Mixtral_sequence_classification_model(self):
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
print(config)
|
|
||||||
config.num_labels = 3
|
config.num_labels = 3
|
||||||
input_ids = input_dict["input_ids"]
|
input_ids = input_dict["input_ids"]
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
|
|||||||
@@ -363,7 +363,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
|
|
||||||
def test_Qwen2_sequence_classification_model(self):
|
def test_Qwen2_sequence_classification_model(self):
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
print(config)
|
|
||||||
config.num_labels = 3
|
config.num_labels = 3
|
||||||
input_ids = input_dict["input_ids"]
|
input_ids = input_dict["input_ids"]
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
|
|||||||
@@ -391,7 +391,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
|
|
||||||
def test_Qwen2Moe_sequence_classification_model(self):
|
def test_Qwen2Moe_sequence_classification_model(self):
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
print(config)
|
|
||||||
config.num_labels = 3
|
config.num_labels = 3
|
||||||
input_ids = input_dict["input_ids"]
|
input_ids = input_dict["input_ids"]
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ class Qwen3ModelTester:
|
|||||||
use_token_type_ids=True,
|
use_token_type_ids=True,
|
||||||
use_labels=True,
|
use_labels=True,
|
||||||
vocab_size=99,
|
vocab_size=99,
|
||||||
hidden_size=32,
|
hidden_size=64,
|
||||||
num_hidden_layers=5,
|
num_hidden_layers=5,
|
||||||
max_window_layers=3,
|
max_window_layers=3,
|
||||||
use_sliding_window=True,
|
use_sliding_window=True,
|
||||||
@@ -348,42 +348,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
self.model_tester = Qwen3ModelTester(self)
|
self.model_tester = Qwen3ModelTester(self)
|
||||||
self.config_tester = ConfigTester(self, config_class=Qwen3Config, hidden_size=37)
|
self.config_tester = ConfigTester(self, config_class=Qwen3Config, hidden_size=37)
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_assisted_decoding_matches_greedy_search_0_random(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_assisted_decoding_matches_greedy_search_1_same(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_assisted_decoding_sample(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_dola_decoding_sample(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_generate_compilation_all_outputs(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
@@ -402,7 +366,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
|
|
||||||
def test_Qwen3_sequence_classification_model(self):
|
def test_Qwen3_sequence_classification_model(self):
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
print(config)
|
|
||||||
config.num_labels = 3
|
config.num_labels = 3
|
||||||
input_ids = input_dict["input_ids"]
|
input_ids = input_dict["input_ids"]
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
@@ -461,9 +424,9 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="Qwen3 uses GQA on all models so the KV cache is a non standard format")
|
# Ignore copy
|
||||||
def test_past_key_values_format(self):
|
def test_past_key_values_format(self):
|
||||||
pass
|
super().test_past_key_values_format()
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@@ -487,7 +450,6 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
|||||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
|
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
|
||||||
# slicing logits[0, 0, 0:30]
|
# slicing logits[0, 0, 0:30]
|
||||||
EXPECTED_SLICE = torch.tensor([5.9062, 6.0938, 5.5625, 3.8594, 2.6094, 1.9531, 4.3125, 4.9375, 3.8906, 3.1094, 3.6719, 5.1562, 6.9062, 5.7500, 5.4062, 7.0625, 8.7500, 8.7500, 8.1250, 7.9375, 8.0625, 7.5312, 7.3750, 7.2188, 7.2500, 5.8750, 2.8750, 4.3438, 2.3438, 2.2500]) # fmt: skip
|
EXPECTED_SLICE = torch.tensor([5.9062, 6.0938, 5.5625, 3.8594, 2.6094, 1.9531, 4.3125, 4.9375, 3.8906, 3.1094, 3.6719, 5.1562, 6.9062, 5.7500, 5.4062, 7.0625, 8.7500, 8.7500, 8.1250, 7.9375, 8.0625, 7.5312, 7.3750, 7.2188, 7.2500, 5.8750, 2.8750, 4.3438, 2.3438, 2.2500]) # fmt: skip
|
||||||
print(out[0, 0, :30])
|
|
||||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
|
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
|
||||||
|
|
||||||
del model
|
del model
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class Qwen3MoeModelTester:
|
|||||||
use_token_type_ids=True,
|
use_token_type_ids=True,
|
||||||
use_labels=True,
|
use_labels=True,
|
||||||
vocab_size=99,
|
vocab_size=99,
|
||||||
hidden_size=32,
|
hidden_size=64,
|
||||||
num_hidden_layers=5,
|
num_hidden_layers=5,
|
||||||
max_window_layers=3,
|
max_window_layers=3,
|
||||||
use_sliding_window=True,
|
use_sliding_window=True,
|
||||||
@@ -367,38 +367,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
self.model_tester = Qwen3MoeModelTester(self)
|
self.model_tester = Qwen3MoeModelTester(self)
|
||||||
self.config_tester = ConfigTester(self, config_class=Qwen3MoeConfig, hidden_size=37)
|
self.config_tester = ConfigTester(self, config_class=Qwen3MoeConfig, hidden_size=37)
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_assisted_decoding_matches_greedy_search_0_random(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_assisted_decoding_matches_greedy_search_1_same(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_assisted_decoding_sample(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_dola_decoding_sample(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("TODO: ask the contributor to take a look")
|
|
||||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
@@ -417,7 +385,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
|
|
||||||
def test_Qwen3Moe_sequence_classification_model(self):
|
def test_Qwen3Moe_sequence_classification_model(self):
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
print(config)
|
|
||||||
config.num_labels = 3
|
config.num_labels = 3
|
||||||
input_ids = input_dict["input_ids"]
|
input_ids = input_dict["input_ids"]
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
@@ -476,9 +443,9 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="Qwen3Moe uses GQA on all models so the KV cache is a non standard format")
|
# Ignore copy
|
||||||
def test_past_key_values_format(self):
|
def test_past_key_values_format(self):
|
||||||
pass
|
super().test_past_key_values_format()
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@@ -539,7 +506,6 @@ class Qwen3MoeIntegrationTest(unittest.TestCase):
|
|||||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
|
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
|
||||||
# slicing logits[0, 0, 0:30]
|
# slicing logits[0, 0, 0:30]
|
||||||
EXPECTED_SLICE = torch.tensor([7.5938, 2.6094, 4.0312, 4.0938, 2.5156, 2.7812, 2.9688, 1.5547, 1.3984, 2.2344, 3.0156, 3.1562, 1.1953, 3.2500, 1.0938, 8.4375, 9.5625, 9.0625, 7.5625, 7.5625, 7.9062, 7.2188, 7.0312, 6.9375, 8.0625, 1.7266, 0.9141, 3.7969, 5.3438, 3.9844]) # fmt: skip
|
EXPECTED_SLICE = torch.tensor([7.5938, 2.6094, 4.0312, 4.0938, 2.5156, 2.7812, 2.9688, 1.5547, 1.3984, 2.2344, 3.0156, 3.1562, 1.1953, 3.2500, 1.0938, 8.4375, 9.5625, 9.0625, 7.5625, 7.5625, 7.9062, 7.2188, 7.0312, 6.9375, 8.0625, 1.7266, 0.9141, 3.7969, 5.3438, 3.9844]) # fmt: skip
|
||||||
print(out[0, 0, :30])
|
|
||||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
|
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
|
||||||
|
|
||||||
del model
|
del model
|
||||||
|
|||||||
Reference in New Issue
Block a user