Several fixes for Gemma3n (#39135)
* remove the skips
* fix the epsilon to a small value (does not make sense otherwise)
* safeguard
* overload test_eager_matches_sdpa
* Update test_modeling_common.py
* skip appropriate tests
* correct no_split_layer
* fix all devices issue
* fix backward
* fix
This commit is contained in:
@@ -39,13 +39,20 @@ from transformers.testing_utils import (
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
from ...test_modeling_common import (
|
||||
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
||||
ModelTesterMixin,
|
||||
_test_eager_matches_sdpa_inference,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
)
|
||||
from ..gemma.test_modeling_gemma import GemmaModelTester
|
||||
|
||||
|
||||
@@ -256,6 +263,7 @@ class Gemma3nTextModelTester(GemmaModelTester):
|
||||
vocab_size=99,
|
||||
vocab_size_per_layer_input=99,
|
||||
hidden_size=16,
|
||||
hidden_size_per_layer_input=16,
|
||||
num_hidden_layers=4, # override to correctly test sharing cache pattern
|
||||
num_kv_shared_layers=2, # important to override
|
||||
layer_types=[
|
||||
@@ -291,6 +299,7 @@ class Gemma3nTextModelTester(GemmaModelTester):
|
||||
self.vocab_size = vocab_size
|
||||
self.vocab_size_per_layer_input = vocab_size_per_layer_input
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_size_per_layer_input = hidden_size_per_layer_input
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_kv_shared_layers = num_kv_shared_layers
|
||||
self.layer_types = layer_types
|
||||
@@ -317,7 +326,6 @@ class Gemma3nTextModelTester(GemmaModelTester):
|
||||
for_causal_lm_class = Gemma3nForCausalLM
|
||||
|
||||
|
||||
@unittest.skip("Skipped for now!")
|
||||
@require_torch
|
||||
class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Gemma3nTextModel, Gemma3nForCausalLM) if is_torch_available() else ()
|
||||
@@ -365,6 +373,64 @@ class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
||||
[expected_shape] * len(iter_hidden_states),
|
||||
)
|
||||
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@require_torch_sdpa
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self,
|
||||
name,
|
||||
torch_dtype,
|
||||
padding_side,
|
||||
use_attention_mask,
|
||||
output_attentions,
|
||||
enable_kernels,
|
||||
):
|
||||
"We need to relax a bit the `atols` for fp32 here due to the altup projections"
|
||||
atols = {
|
||||
("cpu", False, torch.float32): 1e-3, # this was relaxed
|
||||
("cpu", False, torch.float16): 5e-3,
|
||||
("cpu", False, torch.bfloat16): 1e-2,
|
||||
("cpu", True, torch.float32): 1e-3, # this was relaxed
|
||||
("cpu", True, torch.float16): 5e-3,
|
||||
("cpu", True, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float32): 1e-3, # this was relaxed
|
||||
("cuda", False, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float16): 5e-3,
|
||||
("cuda", True, torch.float32): 1e-3, # this was relaxed
|
||||
("cuda", True, torch.bfloat16): 1e-2,
|
||||
("cuda", True, torch.float16): 5e-3,
|
||||
}
|
||||
_test_eager_matches_sdpa_inference(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels, atols=atols
|
||||
)
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(
|
||||
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding"
|
||||
)
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(
|
||||
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding"
|
||||
)
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(
|
||||
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding"
|
||||
)
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(
|
||||
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with dola decoding"
|
||||
)
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma3nVision2TextModelTester:
|
||||
text_config = {"activation_sparsity_pattern": None}
|
||||
|
||||
Reference in New Issue
Block a user