🚨 🚨 Inherited CausalLM Tests (#37590)

* stash commit

* Experiment 1: Try just Gemma

* Experiment 1: Just try Gemma

* make fixup

* Trigger tests

* stash commit

* Try adding Gemma3 as well

* make fixup

* Correct attrib names

* Correct pipeline model mapping

* Add in all_model_classes for Gemma1 again

* Move the pipeline model mapping around again

* make fixup

* Revert Gemma3 changes since it's a VLM

* Let's try Falcon

* Correct attributes

* Correct attributes

* Let's try just overriding get_config() for now

* Do Nemotron too

* And Llama!

* Do llama/persimmon

* Correctly skip tests

* Fix Persimmon

* Include Phimoe

* Fix Gemma2

* Set model_tester_class correctly

* Add GLM

* More models!

* models models models

* make fixup

* Add Qwen3 + Qwen3MoE

* Correct import

* make fixup

* Add the QuestionAnswering classes

* Add the QuestionAnswering classes

* Move pipeline mapping to the right place

* Jetmoe too

* Stop RoPE testing models with no RoPE

* Fix up JetMOE a bit

* Fix up JetMOE a bit

* Can we just force pad_token_id all the time?

* make fixup

* fix starcoder2

* Move pipeline mapping

* Fix RoPE skipping

* Fix RecurrentGemma tests

* Fix Falcon tests

* Add MoE attributes

* Fix values for RoPE testing

* Make sure we set bos_token_id and eos_token_id in an appropriate range

* make fixup

* Fix GLM4

* Add mamba attributes

* Revert bits of JetMOE

* Re-add the JetMOE skips

* Update tests/causal_lm_tester.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add licence

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Matt
2025-05-23 18:29:31 +01:00
committed by GitHub
parent d5f992f5e6
commit 53fb245eb6
25 changed files with 816 additions and 4422 deletions

View File

@@ -28,8 +28,7 @@ from transformers.testing_utils import (
torch_device,
)
from ...models.gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester
from ...test_configuration_common import ConfigTester
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
if is_torch_available():
@@ -43,17 +42,18 @@ if is_torch_available():
)
class Glm4ModelTester(GemmaModelTester):
class Glm4ModelTester(CausalLMModelTester):
if is_torch_available():
config_class = Glm4Config
model_class = Glm4Model
for_causal_lm_class = Glm4ForCausalLM
for_sequence_class = Glm4ForSequenceClassification
for_token_class = Glm4ForTokenClassification
base_model_class = Glm4Model
causal_lm_class = Glm4ForCausalLM
sequence_classification_class = Glm4ForSequenceClassification
token_classification_class = Glm4ForTokenClassification
@require_torch
class Glm4ModelTest(GemmaModelTest, unittest.TestCase):
class Glm4ModelTest(CausalLMModelTest, unittest.TestCase):
model_tester_class = Glm4ModelTester
all_model_classes = (
(Glm4Model, Glm4ForCausalLM, Glm4ForSequenceClassification, Glm4ForTokenClassification)
if is_torch_available()
@@ -75,10 +75,6 @@ class Glm4ModelTest(GemmaModelTest, unittest.TestCase):
_is_stateful = True
model_split_percents = [0.5, 0.6]
def setUp(self):
self.model_tester = Glm4ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Glm4Config, hidden_size=37)
@slow
@require_torch_large_gpu