diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py index 6f4b2eaf2a..44f96efb6d 100644 --- a/src/transformers/models/gemma2/configuration_gemma2.py +++ b/src/transformers/models/gemma2/configuration_gemma2.py @@ -85,6 +85,7 @@ class Gemma2Config(PretrainedConfig): size of the sliding window. final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. ```python >>> from transformers import Gemma2Model, Gemma2Config @@ -98,7 +99,6 @@ class Gemma2Config(PretrainedConfig): model_type = "gemma2" keys_to_ignore_at_inference = ["past_key_values"] - cache_implementation = "hybrid" def __init__( self, @@ -125,6 +125,7 @@ class Gemma2Config(PretrainedConfig): sliding_window=4096, final_logit_softcapping=30.0, attn_logit_softcapping=50.0, + cache_implementation="hybrid", **kwargs, ): super().__init__( @@ -153,3 +154,4 @@ class Gemma2Config(PretrainedConfig): self.sliding_window = sliding_window self.final_logit_softcapping = final_logit_softcapping self.attn_logit_softcapping = attn_logit_softcapping + self.cache_implementation = cache_implementation diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 7aca665096..6decd28a4d 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -117,6 +117,7 @@ class Gemma2Config(PretrainedConfig): size of the sliding window. final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits. attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores. + cache_implementation (`str`, *optional*, defaults to `"hybrid"`): the cache type to be used with `generate`. ```python >>> from transformers import Gemma2Model, Gemma2Config @@ -130,7 +131,6 @@ class Gemma2Config(PretrainedConfig): model_type = "gemma2" keys_to_ignore_at_inference = ["past_key_values"] - cache_implementation = "hybrid" def __init__( self, @@ -157,6 +157,7 @@ class Gemma2Config(PretrainedConfig): sliding_window=4096, final_logit_softcapping=30.0, attn_logit_softcapping=50.0, + cache_implementation="hybrid", **kwargs, ): super().__init__( @@ -185,6 +186,7 @@ class Gemma2Config(PretrainedConfig): self.sliding_window = sliding_window self.final_logit_softcapping = final_logit_softcapping self.attn_logit_softcapping = attn_logit_softcapping + self.cache_implementation = cache_implementation class Gemma2RMSNorm(GemmaRMSNorm): diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 165d478d4f..b9feccd1f9 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -44,7 +44,9 @@ SPECIAL_CASES_TO_ALLOW = { "Qwen2Config": ["use_sliding_window"], "Qwen2MoeConfig": ["use_sliding_window"], "Qwen2VLConfig": ["use_sliding_window"], - "Gemma2Config": ["tie_word_embeddings"], + # `cache_implementation` should be in the default generation config, but we don't yet support per-model + # generation configs (TODO joao) + "Gemma2Config": ["tie_word_embeddings", "cache_implementation"], # used to compute the property `self.chunk_length` "EncodecConfig": ["overlap"], # used to compute the property `self.layers_block_type`