fix t5gemma tests (#39052)
* fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -595,6 +595,11 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
# used in `test_torch_compile_for_training`
|
||||
_torch_compile_train_cls = T5GemmaForConditionalGeneration if is_torch_available() else None
|
||||
# `t5gemma` will give warning or raise error if it is not `eager` during training.
|
||||
_torch_compile_train_attn_implementation = "eager"
|
||||
|
||||
# won't fix
|
||||
test_torchscript = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = T5GemmaModelTester(self)
|
||||
@@ -1584,6 +1589,9 @@ class T5GemmaEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = False
|
||||
model_split_percents = [0.4, 0.5]
|
||||
|
||||
# won't fix
|
||||
test_torchscript = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = T5GemmaEncoderOnlyModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
|
||||
@@ -3748,7 +3748,7 @@ class ModelTesterMixin:
|
||||
self.skipTest(
|
||||
"PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
|
||||
)
|
||||
if config.model_type in ["modernbert", "gemma3"]:
|
||||
if config.model_type in ["modernbert", "gemma3", "t5gemma"]:
|
||||
self.skipTest(
|
||||
reason=f"{config.model_type} currently (transformers==4.52.0) automatically adds an attention_mask input"
|
||||
)
|
||||
@@ -4414,6 +4414,10 @@ class ModelTesterMixin:
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
cls = self._torch_compile_train_cls
|
||||
attn_implementation = getattr(self, "_torch_compile_train_attn_implementation", None)
|
||||
if attn_implementation is not None:
|
||||
config._attn_implementation = attn_implementation
|
||||
|
||||
model = cls(config).to(torch_device)
|
||||
|
||||
inputs = {
|
||||
|
||||
Reference in New Issue
Block a user