fix t5gemma tests (#39052)

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-06-26 18:48:14 +02:00
committed by GitHub
parent 23b7e73f05
commit 2f50230c59
4 changed files with 27 additions and 6 deletions

View File

@@ -3748,7 +3748,7 @@ class ModelTesterMixin:
self.skipTest(
"PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
)
if config.model_type in ["modernbert", "gemma3"]:
if config.model_type in ["modernbert", "gemma3", "t5gemma"]:
self.skipTest(
reason=f"{config.model_type} currently (transformers==4.52.0) automatically adds an attention_mask input"
)
@@ -4414,6 +4414,10 @@ class ModelTesterMixin:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
cls = self._torch_compile_train_cls
attn_implementation = getattr(self, "_torch_compile_train_attn_implementation", None)
if attn_implementation is not None:
config._attn_implementation = attn_implementation
model = cls(config).to(torch_device)
inputs = {