[generate] handle support for cache classes when num enc layers != num dec layers (#40277)

* handle support for cache classes when num enc layers != num dec layers

* handle overwrites

* one more corner case

* Update src/transformers/generation/utils.py

* Update src/transformers/generation/utils.py

* Apply suggestions from code review

* handle corner case :o
This commit is contained in:
Joao Gante
2025-08-21 17:35:18 +01:00
committed by GitHub
parent 7f38068ae0
commit 9568b506ed
9 changed files with 89 additions and 24 deletions

View File

@@ -1168,21 +1168,34 @@ class PretrainedConfig(PushToHubMixin):
return non_default_generation_parameters
def get_text_config(self, decoder=False) -> "PretrainedConfig":
def get_text_config(self, decoder=None, encoder=None) -> "PretrainedConfig":
"""
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
itself. On specific composite models, it is under a set of valid names.
Returns the text config related to the text input (encoder) or text output (decoder) of the model. The
`decoder` and `encoder` input arguments can be used to specify which end of the model we are interested in,
which is useful on models that have both text input and output modalities.
There are three possible outcomes of using this method:
1. On most models, it returns the original config instance itself.
2. On newer (2024+) composite models, it returns the text section of the config, which is nested under a set
of valid names.
3. On older (2023-) composite models, it discards decoder-only parameters when `encoder=True` and vice-versa.
Args:
decoder (`Optional[bool]`, *optional*, defaults to `False`):
decoder (`Optional[bool]`, *optional*):
If set to `True`, then only search for decoder config names.
encoder (`Optional[bool]`, *optional*):
If set to `True`, then only search for encoder config names.
"""
return_both = decoder == encoder # both unset or both set -> search all possible names
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
encoder_possible_text_config_names = ("text_encoder",)
if decoder:
if return_both:
possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
elif decoder:
possible_text_config_names = decoder_possible_text_config_names
else:
possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
possible_text_config_names = encoder_possible_text_config_names
valid_text_config_names = []
for text_config_name in possible_text_config_names:
@@ -1194,12 +1207,27 @@ class PretrainedConfig(PushToHubMixin):
if len(valid_text_config_names) > 1:
raise ValueError(
f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
"case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
"case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly, "
"e.g. `text_config = config.sub_config_name`"
)
elif len(valid_text_config_names) == 1:
config_to_return = getattr(self, valid_text_config_names[0])
else:
config_to_return = self
# handle legacy models with flat config structure, when we only want one of the configs
if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
config_to_return = copy.deepcopy(config_to_return)
prefix_to_discard = "encoder" if decoder else "decoder"
for key in config_to_return.to_dict():
if key.startswith(prefix_to_discard):
delattr(config_to_return, key)
# old encoder/decoder models may use "encoder_layers"/"decoder_layers" instead of "num_hidden_layers"
if decoder and hasattr(config_to_return, "decoder_layers"):
config_to_return.num_hidden_layers = config_to_return.decoder_layers
elif encoder and hasattr(config_to_return, "encoder_layers"):
config_to_return.num_hidden_layers = config_to_return.encoder_layers
return config_to_return
@classmethod

View File

@@ -1844,12 +1844,19 @@ class GenerationMixin(ContinuousMixin):
)
if need_new_cache:
cache_kwargs = {"config": self.config, "max_cache_len": max_cache_len, "offloading": offload_cache}
self._cache = StaticCache(**cache_kwargs)
self_attention_cache_kwargs = {
"config": self.config.get_text_config(decoder=True),
"max_cache_len": max_cache_len,
"offloading": offload_cache,
}
self._cache = StaticCache(**self_attention_cache_kwargs)
if requires_cross_attention_cache:
encoder_kwargs = cache_kwargs.copy()
encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
self._cache = EncoderDecoderCache(self._cache, StaticCache(**encoder_kwargs))
cross_attention_cache_kwargs = {
"config": self.config.get_text_config(encoder=True),
"max_cache_len": model_kwargs["encoder_outputs"][0].shape[1],
"offloading": offload_cache,
}
self._cache = EncoderDecoderCache(self._cache, StaticCache(**cross_attention_cache_kwargs))
else:
self._cache.reset()
return self._cache

View File

@@ -87,8 +87,8 @@ class ColQwen2Config(PretrainedConfig):
self.initializer_range = initializer_range
super().__init__(**kwargs)
def get_text_config(self, decoder=False) -> PretrainedConfig:
return self.vlm_config.get_text_config(decoder=decoder)
def get_text_config(self, *args, **kwargs) -> PretrainedConfig:
return self.vlm_config.get_text_config(*args, **kwargs)
__all__ = ["ColQwen2Config"]

View File

@@ -368,7 +368,7 @@ class DiaConfig(PretrainedConfig):
**kwargs,
)
def get_text_config(self, decoder=False):
def get_text_config(self, *args, **kwargs):
"""Defaulting to audio config as it's the decoder in this case which is usually the text backbone"""
return self.decoder_config

View File

@@ -1073,7 +1073,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
super().__init__(**kwargs)
def get_text_config(self, decoder=False):
def get_text_config(self, *args, **kwargs):
"""
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
itself. On specific composite models, it is under a set of valid names.
@@ -1085,7 +1085,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
# Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
# except for Qwen yet. This has to be generalized if more deeply nested configs are
# added. NOTE: currently method used only by vLLM
return self.thinker_config.get_text_config()
return self.thinker_config.get_text_config(*args, **kwargs)
__all__ = ["Qwen2_5OmniConfig", "Qwen2_5OmniThinkerConfig", "Qwen2_5OmniTalkerConfig", "Qwen2_5OmniToken2WavConfig"]

View File

@@ -1108,7 +1108,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
super().__init__(**kwargs)
def get_text_config(self, decoder=False):
def get_text_config(self, *args, **kwargs):
"""
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
itself. On specific composite models, it is under a set of valid names.
@@ -1120,7 +1120,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
# Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
# except for Qwen yet. This has to be generalized if more deeply nested configs are
# added. NOTE: currently method used only by vLLM
return self.thinker_config.get_text_config()
return self.thinker_config.get_text_config(*args, **kwargs)
class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):

View File

@@ -323,9 +323,8 @@ class T5GemmaConfig(PretrainedConfig):
setattr(self.decoder, key, value)
super().__setattr__(key, value)
def get_text_config(self, decoder=False):
def get_text_config(self, *args, **kwargs):
# Always return self, regardless of the decoder option.
del decoder
return self

View File

@@ -213,9 +213,8 @@ class T5GemmaConfig(PretrainedConfig):
setattr(self.decoder, key, value)
super().__setattr__(key, value)
def get_text_config(self, decoder=False):
def get_text_config(self, *args, **kwargs):
# Always return self, regardless of the decoder option.
del decoder
return self

View File

@@ -24,7 +24,7 @@ from pathlib import Path
from huggingface_hub import HfFolder
from requests.exceptions import HTTPError
from transformers import AutoConfig, BertConfig, GPT2Config
from transformers import AutoConfig, BertConfig, Florence2Config, GPT2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test
@@ -300,3 +300,35 @@ class ConfigTestUtils(unittest.TestCase):
with warnings.catch_warnings():
warnings.simplefilter("error")
PretrainedConfig.from_pretrained("bert-base-uncased")
def test_get_text_config(self):
"""Tests the `get_text_config` method."""
# 1. model with only text input -> returns the original config instance
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
self.assertEqual(config.get_text_config(), config)
self.assertEqual(config.get_text_config(decoder=True), config)
# 2. composite model (VLM) -> returns the text component
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlavaForConditionalGeneration")
self.assertEqual(config.get_text_config(), config.text_config)
self.assertEqual(config.get_text_config(decoder=True), config.text_config)
# 3. ! corner case! : composite model whose sub-config is an old composite model (should behave as above)
config = Florence2Config()
self.assertEqual(config.get_text_config(), config.text_config)
self.assertEqual(config.get_text_config(decoder=True), config.text_config)
# 4. old composite model -> may remove components based on the `decoder` or `encoder` argument
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bart")
self.assertEqual(config.get_text_config(), config)
# both encoder_layers and decoder_layers exist
self.assertTrue(getattr(config, "encoder_layers", None) is not None)
self.assertTrue(getattr(config, "decoder_layers", None) is not None)
decoder_config = config.get_text_config(decoder=True)
self.assertNotEqual(decoder_config, config)
self.assertEqual(decoder_config.num_hidden_layers, config.decoder_layers)
self.assertTrue(getattr(decoder_config, "encoder_layers", None) is None) # encoder_layers is removed
encoder_config = config.get_text_config(encoder=True)
self.assertNotEqual(encoder_config, config)
self.assertEqual(encoder_config.num_hidden_layers, config.encoder_layers)
self.assertTrue(getattr(encoder_config, "decoder_layers", None) is None) # decoder_layers is removed