Several fixes for Gemma3n (#39135)

* remove the skips

* fix the epsilon to a small value (does not make sense otherwise)

* safeguard

* overload test_eager_matches_sdpa

* Update test_modeling_common.py

* skip appropriate tests

* correct no_split_layer

* fix all devices issue

* fix backward

* fix
This commit is contained in:
Cyril Vallez
2025-07-01 10:34:53 +02:00
committed by GitHub
parent d53518c5f2
commit dbc98328da
5 changed files with 491 additions and 390 deletions

View File

@@ -39,13 +39,20 @@ from transformers.testing_utils import (
require_read_token,
require_torch,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
_test_eager_matches_sdpa_inference,
floats_tensor,
ids_tensor,
)
from ..gemma.test_modeling_gemma import GemmaModelTester
@@ -256,6 +263,7 @@ class Gemma3nTextModelTester(GemmaModelTester):
vocab_size=99,
vocab_size_per_layer_input=99,
hidden_size=16,
hidden_size_per_layer_input=16,
num_hidden_layers=4, # override to correctly test sharing cache pattern
num_kv_shared_layers=2, # important to override
layer_types=[
@@ -291,6 +299,7 @@ class Gemma3nTextModelTester(GemmaModelTester):
self.vocab_size = vocab_size
self.vocab_size_per_layer_input = vocab_size_per_layer_input
self.hidden_size = hidden_size
self.hidden_size_per_layer_input = hidden_size_per_layer_input
self.num_hidden_layers = num_hidden_layers
self.num_kv_shared_layers = num_kv_shared_layers
self.layer_types = layer_types
@@ -317,7 +326,6 @@ class Gemma3nTextModelTester(GemmaModelTester):
for_causal_lm_class = Gemma3nForCausalLM
@unittest.skip("Skipped for now!")
@require_torch
class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Gemma3nTextModel, Gemma3nForCausalLM) if is_torch_available() else ()
@@ -365,6 +373,64 @@ class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
[expected_shape] * len(iter_hidden_states),
)
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
def test_eager_matches_sdpa_inference(
    self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
    """Run the shared eager-vs-SDPA inference check with Gemma3n-specific tolerances.

    The fp32 absolute tolerance is relaxed a bit here due to the altup projections.
    All arguments come from the parameterization and are forwarded unchanged to
    `_test_eager_matches_sdpa_inference`, together with the custom `atols` table.
    """
    # In this table the tolerance depends only on the dtype.
    atol_per_dtype = {
        torch.float32: 1e-3,  # this was relaxed
        torch.float16: 5e-3,
        torch.bfloat16: 1e-2,
    }
    # Keys are (device, bool, dtype) triples.
    # NOTE(review): the boolean is presumably the kernels flag -- confirm against the shared helper.
    atols = {
        (device, flag, dtype): atol
        for device in ("cpu", "cuda")
        for flag in (False, True)
        for dtype, atol in atol_per_dtype.items()
    }
    _test_eager_matches_sdpa_inference(
        self,
        name,
        torch_dtype,
        padding_side,
        use_attention_mask,
        output_attentions,
        enable_kernels,
        atols=atols,
    )
# Opt-out stub: the test is unconditionally skipped and the body is intentionally empty;
# the reason is carried by the skip decorator's message below.
# NOTE(review): presumably shadows the same-named test inherited from a tester mixin -- confirm.
@pytest.mark.generate
@unittest.skip(
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding"
)
def test_contrastive_generate(self):
pass
# Opt-out stub: the test is unconditionally skipped and the body is intentionally empty;
# the reason is carried by the skip decorator's message below.
# NOTE(review): presumably shadows the same-named test inherited from a tester mixin -- confirm.
@pytest.mark.generate
@unittest.skip(
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding"
)
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
# Opt-out stub: the test is unconditionally skipped and the body is intentionally empty;
# the reason is carried by the skip decorator's message below.
# NOTE(review): presumably shadows the same-named test inherited from a tester mixin -- confirm.
@pytest.mark.generate
@unittest.skip(
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding"
)
def test_contrastive_generate_low_memory(self):
pass
# Opt-out stub: the test is unconditionally skipped and the body is intentionally empty;
# the reason is carried by the skip decorator's message below.
# NOTE(review): presumably shadows the same-named test inherited from a tester mixin -- confirm.
@pytest.mark.generate
@unittest.skip(
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with dola decoding"
)
def test_dola_decoding_sample(self):
pass
class Gemma3nVision2TextModelTester:
text_config = {"activation_sparsity_pattern": None}