[gemma 3] multimodal checkpoints + AutoModelForCausalLM (#36741)

This commit is contained in:
Joao Gante
2025-03-19 15:04:19 +00:00
committed by GitHub
parent b11050d6a2
commit 7c233980f4
3 changed files with 34 additions and 0 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Gemma3 model."""
import tempfile
import unittest
import pytest
@@ -339,6 +340,18 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_flex_attention_with_grads(self):
pass
def test_automodelforcausallm(self):
"""
Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
"""
config = self.model_tester.get_config()
model = Gemma3ForConditionalGeneration(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
self.assertIsInstance(for_causal_lm, Gemma3ForCausalLM)
@slow
@require_torch_gpu