[gemma 3] multimodal checkpoints + AutoModelForCausalLM (#36741)
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Gemma3 model."""
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
@@ -339,6 +340,18 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
def test_automodelforcausallm(self):
|
||||
"""
|
||||
Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
|
||||
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
|
||||
"""
|
||||
config = self.model_tester.get_config()
|
||||
model = Gemma3ForConditionalGeneration(config)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
|
||||
self.assertIsInstance(for_causal_lm, Gemma3ForCausalLM)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user