🚨 🚨 Inherited CausalLM Tests (#37590)
* stash commit
* Experiment 1: Try just Gemma
* Experiment 1: Just try Gemma
* make fixup
* Trigger tests
* stash commit
* Try adding Gemma3 as well
* make fixup
* Correct attrib names
* Correct pipeline model mapping
* Add in all_model_classes for Gemma1 again
* Move the pipeline model mapping around again
* make fixup
* Revert Gemma3 changes since it's a VLM
* Let's try Falcon
* Correct attributes
* Correct attributes
* Let's try just overriding get_config() for now
* Do Nemotron too
* And Llama!
* Do llama/persimmon
* Correctly skip tests
* Fix Persimmon
* Include Phimoe
* Fix Gemma2
* Set model_tester_class correctly
* Add GLM
* More models!
* models models models
* make fixup
* Add Qwen3 + Qwen3MoE
* Correct import
* make fixup
* Add the QuestionAnswering classes
* Add the QuestionAnswering classes
* Move pipeline mapping to the right place
* Jetmoe too
* Stop RoPE testing models with no RoPE
* Fix up JetMOE a bit
* Fix up JetMOE a bit
* Can we just force pad_token_id all the time?
* make fixup
* fix starcoder2
* Move pipeline mapping
* Fix RoPE skipping
* Fix RecurrentGemma tests
* Fix Falcon tests
* Add MoE attributes
* Fix values for RoPE testing
* Make sure we set bos_token_id and eos_token_id in an appropriate range
* make fixup
* Fix GLM4
* Add mamba attributes
* Revert bits of JetMOE
* Re-add the JetMOE skips
* Update tests/causal_lm_tester.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add licence

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -28,8 +28,7 @@ from transformers.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...models.gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -43,17 +42,18 @@ if is_torch_available():
|
||||
)
|
||||
|
||||
|
||||
class Glm4ModelTester(GemmaModelTester):
|
||||
class Glm4ModelTester(CausalLMModelTester):
|
||||
if is_torch_available():
|
||||
config_class = Glm4Config
|
||||
model_class = Glm4Model
|
||||
for_causal_lm_class = Glm4ForCausalLM
|
||||
for_sequence_class = Glm4ForSequenceClassification
|
||||
for_token_class = Glm4ForTokenClassification
|
||||
base_model_class = Glm4Model
|
||||
causal_lm_class = Glm4ForCausalLM
|
||||
sequence_classification_class = Glm4ForSequenceClassification
|
||||
token_classification_class = Glm4ForTokenClassification
|
||||
|
||||
|
||||
@require_torch
|
||||
class Glm4ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
class Glm4ModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
model_tester_class = Glm4ModelTester
|
||||
all_model_classes = (
|
||||
(Glm4Model, Glm4ForCausalLM, Glm4ForSequenceClassification, Glm4ForTokenClassification)
|
||||
if is_torch_available()
|
||||
@@ -75,10 +75,6 @@ class Glm4ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
_is_stateful = True
|
||||
model_split_percents = [0.5, 0.6]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Glm4ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Glm4Config, hidden_size=37)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_large_gpu
|
||||
|
||||
Reference in New Issue
Block a user