[Generation, Gemma 3] When passing a custom generation_config, overwrite default values with the model's base generation_config (#36684)
This commit is contained in:
@@ -73,7 +73,7 @@ if is_torch_available():
|
|||||||
}
|
}
|
||||||
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
||||||
ALL_CACHE_IMPLEMENTATIONS = (
|
ALL_CACHE_IMPLEMENTATIONS = (
|
||||||
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded"]
|
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded", "dynamic"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -175,6 +175,7 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
cache_implementation (`str`, *optional*, default to `None`):
|
cache_implementation (`str`, *optional*, default to `None`):
|
||||||
Name of the cache class that will be instantiated in `generate`, for faster decoding. Possible values are:
|
Name of the cache class that will be instantiated in `generate`, for faster decoding. Possible values are:
|
||||||
|
|
||||||
|
- `"dynamic"`: [`DynamicCache`]
|
||||||
- `"static"`: [`StaticCache`]
|
- `"static"`: [`StaticCache`]
|
||||||
- `"offloaded_static"`: [`OffloadedStaticCache`]
|
- `"offloaded_static"`: [`OffloadedStaticCache`]
|
||||||
- `"sliding_window"`: [`SlidingWindowCache`]
|
- `"sliding_window"`: [`SlidingWindowCache`]
|
||||||
@@ -182,9 +183,8 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
- `"mamba"`: [`MambaCache`]
|
- `"mamba"`: [`MambaCache`]
|
||||||
- `"quantized"`: [`QuantizedCache`]
|
- `"quantized"`: [`QuantizedCache`]
|
||||||
|
|
||||||
We support other cache types, but they must be manually instantiated and
|
If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See
|
||||||
passed to `generate` through the `past_key_values` argument. See our
|
our [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
|
||||||
[cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
|
|
||||||
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
|
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
|
||||||
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
|
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
|
||||||
it will be converted to its repsective `CacheConfig` internally.
|
it will be converted to its repsective `CacheConfig` internally.
|
||||||
|
|||||||
@@ -1177,21 +1177,37 @@ class GenerationMixin:
|
|||||||
default_list: Union[LogitsProcessorList, StoppingCriteriaList],
|
default_list: Union[LogitsProcessorList, StoppingCriteriaList],
|
||||||
custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
|
custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
|
||||||
) -> Union[LogitsProcessorList, StoppingCriteriaList]:
|
) -> Union[LogitsProcessorList, StoppingCriteriaList]:
|
||||||
|
"""
|
||||||
|
Merge user-defined processors/criteria with the ones instantiated inside `generate`. In case the same
|
||||||
|
processor/criteria is present on both lists, use the user-defined one.
|
||||||
|
|
||||||
|
(Note: up to v4.49.0, this funtion threw an exception is the same logit processor was found twice.)
|
||||||
|
"""
|
||||||
if len(custom_list) == 0:
|
if len(custom_list) == 0:
|
||||||
return default_list
|
return default_list
|
||||||
|
|
||||||
|
final_list = type(default_list)()
|
||||||
for default in default_list:
|
for default in default_list:
|
||||||
|
using_custom = False
|
||||||
for custom in custom_list:
|
for custom in custom_list:
|
||||||
if type(custom) is type(default):
|
if type(custom) is type(default):
|
||||||
object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
|
object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
|
||||||
raise ValueError(
|
logger.warning_once(
|
||||||
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
|
f"A custom {object_type} of type {type(custom)} has been passed to `.generate()`, but it "
|
||||||
f" `.generate()`, but it has already been created with the values {default}. {default} has been"
|
f"was also created in `.generate()`, given its parameterization. The custom {type(custom)} "
|
||||||
" created by passing the corresponding arguments to generate or by the model's config default"
|
f"will take precedence. Please check the docstring of {type(custom)} to see related "
|
||||||
f" values. If you just want to change the default values of {object_type} consider passing"
|
"`.generate()` flags."
|
||||||
f" them as arguments to `.generate()` instead of using a custom {object_type}."
|
|
||||||
)
|
)
|
||||||
default_list.extend(custom_list)
|
final_list.append(custom)
|
||||||
return default_list
|
using_custom = True
|
||||||
|
break
|
||||||
|
if not using_custom:
|
||||||
|
final_list.append(default)
|
||||||
|
|
||||||
|
for custom in custom_list:
|
||||||
|
if custom not in final_list:
|
||||||
|
final_list.append(custom)
|
||||||
|
return final_list
|
||||||
|
|
||||||
def compute_transition_scores(
|
def compute_transition_scores(
|
||||||
self,
|
self,
|
||||||
@@ -1573,17 +1589,28 @@ class GenerationMixin:
|
|||||||
# exception will be raised in `_validate_model_kwargs`
|
# exception will be raised in `_validate_model_kwargs`
|
||||||
if not is_torchdynamo_compiling():
|
if not is_torchdynamo_compiling():
|
||||||
generation_config = copy.deepcopy(generation_config)
|
generation_config = copy.deepcopy(generation_config)
|
||||||
model_kwargs = generation_config.update(**kwargs)
|
|
||||||
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
|
# If `generation_config` is provided, let's fallback ALL default values to the model's generation config
|
||||||
|
# TODO (joao): per-model generation config classes.
|
||||||
if not using_model_generation_config:
|
if not using_model_generation_config:
|
||||||
if generation_config.bos_token_id is None:
|
modified_values = {}
|
||||||
generation_config.bos_token_id = self.generation_config.bos_token_id
|
default_generation_config = GenerationConfig()
|
||||||
if generation_config.eos_token_id is None:
|
for key, default_value in default_generation_config.__dict__.items():
|
||||||
generation_config.eos_token_id = self.generation_config.eos_token_id
|
if key.startswith("_"): # metadata
|
||||||
if generation_config.pad_token_id is None:
|
continue
|
||||||
generation_config.pad_token_id = self.generation_config.pad_token_id
|
custom_gen_config_value = getattr(generation_config, key)
|
||||||
if generation_config.decoder_start_token_id is None:
|
model_gen_config_value = getattr(self.generation_config, key)
|
||||||
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
|
if custom_gen_config_value == default_value and model_gen_config_value != default_value:
|
||||||
|
modified_values[key] = model_gen_config_value
|
||||||
|
setattr(generation_config, key, model_gen_config_value)
|
||||||
|
if len(modified_values) > 0:
|
||||||
|
logger.warning_once(
|
||||||
|
f"`generation_config` default values have been modified to match model-specific defaults: "
|
||||||
|
f"{modified_values}. If this is not desired, please set these values explicitly."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Finally, apply any passed kwargs
|
||||||
|
model_kwargs = generation_config.update(**kwargs)
|
||||||
else:
|
else:
|
||||||
model_kwargs = kwargs
|
model_kwargs = kwargs
|
||||||
|
|
||||||
@@ -1837,6 +1864,8 @@ class GenerationMixin:
|
|||||||
model_kwargs[cache_name] = cache_class(cache_config)
|
model_kwargs[cache_name] = cache_class(cache_config)
|
||||||
elif generation_config.cache_implementation == "offloaded":
|
elif generation_config.cache_implementation == "offloaded":
|
||||||
model_kwargs[cache_name] = OffloadedCache()
|
model_kwargs[cache_name] = OffloadedCache()
|
||||||
|
elif generation_config.cache_implementation == "dynamic":
|
||||||
|
model_kwargs[cache_name] = DynamicCache()
|
||||||
|
|
||||||
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
|
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
|
||||||
# keeps copying the cache thus using much more memory
|
# keeps copying the cache thus using much more memory
|
||||||
|
|||||||
@@ -1162,8 +1162,8 @@ class GenerationTesterMixin:
|
|||||||
# The two outputs must match and their shape must be as expected
|
# The two outputs must match and their shape must be as expected
|
||||||
self._check_similar_generate_outputs(low_output, high_output)
|
self._check_similar_generate_outputs(low_output, high_output)
|
||||||
|
|
||||||
@pytest.mark.generate
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@pytest.mark.generate
|
||||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||||
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
|
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -261,6 +262,7 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
@@ -269,6 +271,7 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_sample(self):
|
def test_assisted_decoding_sample(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from pytest import mark
|
from pytest import mark
|
||||||
@@ -81,6 +82,7 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
@@ -89,6 +91,7 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
|||||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_sample(self):
|
def test_assisted_decoding_sample(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -299,12 +299,13 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@pytest.mark.generate
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
|
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
|
||||||
def test_assisted_decoding_matches_greedy_search(self):
|
def test_assisted_decoding_matches_greedy_search(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
|
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
|
||||||
def test_assisted_decoding_sample(self):
|
def test_assisted_decoding_sample(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from pytest import mark
|
from pytest import mark
|
||||||
@@ -96,6 +97,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
@@ -104,6 +106,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
|||||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_sample(self):
|
def test_assisted_decoding_sample(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -23,6 +24,7 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
Gemma3Config,
|
Gemma3Config,
|
||||||
Gemma3TextConfig,
|
Gemma3TextConfig,
|
||||||
|
GenerationConfig,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@@ -75,6 +77,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
@@ -83,6 +86,7 @@ class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
|
|||||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_sample(self):
|
def test_assisted_decoding_sample(self):
|
||||||
pass
|
pass
|
||||||
@@ -277,6 +281,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
@@ -285,6 +290,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
|||||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_sample(self):
|
def test_assisted_decoding_sample(self):
|
||||||
pass
|
pass
|
||||||
@@ -551,3 +557,34 @@ class Gemma3IntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||||
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||||
|
|
||||||
|
def test_generation_beyond_sliding_window_with_generation_config(self):
|
||||||
|
"""
|
||||||
|
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
|
||||||
|
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
|
||||||
|
"""
|
||||||
|
model_id = "gg-hf-g/gemma-3-1b-it"
|
||||||
|
attn_implementation = "sdpa"
|
||||||
|
|
||||||
|
input_text = [
|
||||||
|
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
|
||||||
|
"A list of colors: red, blue", # This will almost all be padding tokens
|
||||||
|
]
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
|
||||||
|
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
# Make sure prefill is larger than sliding window
|
||||||
|
input_size = inputs.input_ids.shape[-1]
|
||||||
|
self.assertTrue(input_size > model.config.sliding_window)
|
||||||
|
|
||||||
|
generation_config = GenerationConfig(max_new_tokens=20)
|
||||||
|
|
||||||
|
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
|
||||||
|
output_text = tokenizer.batch_decode(out)
|
||||||
|
|
||||||
|
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
|
||||||
|
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -351,6 +352,7 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
@@ -359,6 +361,7 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
|||||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||||
def test_assisted_decoding_sample(self):
|
def test_assisted_decoding_sample(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -16,6 +16,8 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
|
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
@@ -375,6 +377,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
|||||||
def test_model_parallel_beam_search(self):
|
def test_model_parallel_beam_search(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
|
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
|
||||||
def test_assisted_decoding_matches_greedy_search(self):
|
def test_assisted_decoding_matches_greedy_search(self):
|
||||||
pass
|
pass
|
||||||
@@ -383,6 +386,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
|||||||
def test_left_padding_compatibility(self):
|
def test_left_padding_compatibility(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip(reason="Relies on `past_key_values` returned by the model. Not supported with recurrent gemma")
|
@unittest.skip(reason="Relies on `past_key_values` returned by the model. Not supported with recurrent gemma")
|
||||||
def test_assisted_decoding_sample(self):
|
def test_assisted_decoding_sample(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -423,6 +423,7 @@ class SmolVLMForConditionalGenerationModelTest(GenerationTesterMixin, ModelTeste
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@pytest.mark.generate
|
||||||
@unittest.skip(reason="Cache position is off by one leaving out image tokens, FIXME raushan")
|
@unittest.skip(reason="Cache position is off by one leaving out image tokens, FIXME raushan")
|
||||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user