🚨 🚨 Inherited CausalLM Tests (#37590)

* stash commit

* Experiment 1: Try just Gemma

* Experiment 1: Just try Gemma

* make fixup

* Trigger tests

* stash commit

* Try adding Gemma3 as well

* make fixup

* Correct attrib names

* Correct pipeline model mapping

* Add in all_model_classes for Gemma1 again

* Move the pipeline model mapping around again

* make fixup

* Revert Gemma3 changes since it's a VLM

* Let's try Falcon

* Correct attributes

* Correct attributes

* Let's try just overriding get_config() for now

* Do Nemotron too

* And Llama!

* Do llama/persimmon

* Correctly skip tests

* Fix Persimmon

* Include Phimoe

* Fix Gemma2

* Set model_tester_class correctly

* Add GLM

* More models!

* models models models

* make fixup

* Add Qwen3 + Qwen3MoE

* Correct import

* make fixup

* Add the QuestionAnswering classes

* Add the QuestionAnswering classes

* Move pipeline mapping to the right place

* Jetmoe too

* Stop RoPE testing models with no RoPE

* Fix up JetMOE a bit

* Fix up JetMOE a bit

* Can we just force pad_token_id all the time?

* make fixup

* fix starcoder2

* Move pipeline mapping

* Fix RoPE skipping

* Fix RecurrentGemma tests

* Fix Falcon tests

* Add MoE attributes

* Fix values for RoPE testing

* Make sure we set bos_token_id and eos_token_id in an appropriate range

* make fixup

* Fix GLM4

* Add mamba attributes

* Revert bits of JetMOE

* Re-add the JetMOE skips

* Update tests/causal_lm_tester.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add licence

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Matt
2025-05-23 18:29:31 +01:00
committed by GitHub
parent d5f992f5e6
commit 53fb245eb6
25 changed files with 816 additions and 4422 deletions

View File

@@ -16,6 +16,7 @@
import unittest
import pytest
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
@@ -27,151 +28,26 @@ from transformers.testing_utils import (
torch_device,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import RecurrentGemmaForCausalLM, RecurrentGemmaModel
from transformers import RecurrentGemmaConfig, RecurrentGemmaForCausalLM, RecurrentGemmaModel
class RecurrentGemmaModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=12,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
num_hidden_layers=3,
vocab_size=99,
hidden_size=32,
intermediate_size=3 * 32,
num_attention_heads=2,
lru_width=2 * 32,
embeddings_scale_by_sqrt_dim=True,
attention_window_size=16,
conv1d_width=4,
logits_soft_cap=30.0,
rms_norm_eps=1e-6,
use_cache=True,
rope_theta=10000.0,
type_vocab_size=16,
type_sequence_label_size=2,
num_labels=3,
num_choices=4,
pad_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
self.num_hidden_layers = num_hidden_layers
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.lru_width = lru_width if lru_width is not None else hidden_size
self.embeddings_scale_by_sqrt_dim = embeddings_scale_by_sqrt_dim
self.attention_window_size = attention_window_size
self.conv1d_width = conv1d_width
self.logits_soft_cap = logits_soft_cap
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return RecurrentGemmaConfig(
num_hidden_layers=self.num_hidden_layers,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
num_attention_heads=self.num_attention_heads,
lru_width=self.lru_width,
embeddings_scale_by_sqrt_dim=self.embeddings_scale_by_sqrt_dim,
attention_window_size=self.attention_window_size,
conv1d_width=self.conv1d_width,
logits_soft_cap=self.logits_soft_cap,
rms_norm_eps=self.rms_norm_eps,
use_cache=self.use_cache,
rope_theta=self.rope_theta,
pad_token_id=self.pad_token_id,
output_attentions=False,
)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->RecurrentGemma
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = RecurrentGemmaModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->RecurrentGemma
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
class RecurrentGemmaModelTester(CausalLMModelTester):
config_class = RecurrentGemmaConfig
if is_torch_available():
base_model_class = RecurrentGemmaModel
causal_lm_class = RecurrentGemmaForCausalLM
@require_torch
class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (RecurrentGemmaForCausalLM,) if is_torch_available() else ()
# Doesn't run generation tests. TODO @gante not fully supported
all_generative_model_classes = ()
class RecurrentGemmaModelTest(CausalLMModelTest, unittest.TestCase):
all_model_classes = (RecurrentGemmaModel, RecurrentGemmaForCausalLM) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": RecurrentGemmaModel,
@@ -180,48 +56,10 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
if is_torch_available()
else {}
)
fx_compatible = False # FIXME let's try to support this @ArthurZucker
test_torchscript = False # FIXME let's try to support this @ArthurZucker
test_missing_keys = False
test_model_parallel = False
test_headmasking = False
test_pruning = False
test_head_masking = False # RecurrentGemma does not have attention heads
# Need to remove 0.9 in `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.6]
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self,
pipeline_test_case_name,
config_class,
model_architecture,
tokenizer_name,
image_processor_name,
feature_extractor_name,
processor_name,
):
return True
def setUp(self):
# We don't output attentions
self.has_attentions = False
self.model_tester = RecurrentGemmaModelTester(self)
self.config_tester = ConfigTester(self, config_class=RecurrentGemmaConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
has_attentions = False
model_tester_class = RecurrentGemmaModelTester
@unittest.skip(reason="RecurrentGemma only supports sdpa")
def test_eager_matches_sdpa_generate(self):
@@ -255,6 +93,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
def test_model_parallel_beam_search(self):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
def test_assisted_decoding_matches_greedy_search(self):
@@ -273,6 +112,65 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
def test_initialization(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_beam_sample_generate_dict_output(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_beam_search_generate_dict_output(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_constrained_beam_search_generate_dict_output(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_dola_decoding_sample(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_group_beam_search_generate(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_group_beam_search_generate_dict_output(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_constrained_beam_search_generate(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_greedy_generate_dict_outputs(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
@pytest.mark.generate
def test_greedy_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
def test_model_outputs_equivalence(self):
pass
@require_torch_accelerator
@slow