[Generation, Gemma 3] When passing a custom generation_config, overwrite default values with the model's base generation_config (#36684)

This commit is contained in:
Joao Gante
2025-03-15 12:40:09 +00:00
committed by GitHub
parent f263e88dcf
commit fc8764c9a6
11 changed files with 108 additions and 24 deletions

View File

@@ -16,6 +16,8 @@
import unittest
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
@@ -375,6 +377,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
def test_model_parallel_beam_search(self):
pass
@pytest.mark.generate
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
def test_assisted_decoding_matches_greedy_search(self):
pass
@@ -383,6 +386,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
def test_left_padding_compatibility(self):
pass
@pytest.mark.generate
@unittest.skip(reason="Relies on `past_key_values` returned by the model. Not supported with recurrent gemma")
def test_assisted_decoding_sample(self):
pass