[generate] handle support for cache classes when num enc layers != num dec layers (#40277)
* handle support for cache classes when num enc layers != num dec layers * handle overwrites * one more corner case * Update src/transformers/generation/utils.py * Update src/transformers/generation/utils.py * Apply suggestions from code review * handle corner case :o
This commit is contained in:
@@ -1168,21 +1168,34 @@ class PretrainedConfig(PushToHubMixin):
|
||||
|
||||
return non_default_generation_parameters
|
||||
|
||||
def get_text_config(self, decoder=False) -> "PretrainedConfig":
|
||||
def get_text_config(self, decoder=None, encoder=None) -> "PretrainedConfig":
|
||||
"""
|
||||
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
|
||||
itself. On specific composite models, it is under a set of valid names.
|
||||
Returns the text config related to the text input (encoder) or text output (decoder) of the model. The
|
||||
`decoder` and `encoder` input arguments can be used to specify which end of the model we are interested in,
|
||||
which is useful on models that have both text input and output modalities.
|
||||
|
||||
There are three possible outcomes of using this method:
|
||||
1. On most models, it returns the original config instance itself.
|
||||
2. On newer (2024+) composite models, it returns the text section of the config, which is nested under a set
|
||||
of valid names.
|
||||
3. On older (2023-) composite models, it discards decoder-only parameters when `encoder=True` and vice-versa.
|
||||
|
||||
Args:
|
||||
decoder (`Optional[bool]`, *optional*, defaults to `False`):
|
||||
decoder (`Optional[bool]`, *optional*):
|
||||
If set to `True`, then only search for decoder config names.
|
||||
encoder (`Optional[bool]`, *optional*):
|
||||
If set to `True`, then only search for encoder config names.
|
||||
"""
|
||||
return_both = decoder == encoder # both unset or both set -> search all possible names
|
||||
|
||||
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
|
||||
encoder_possible_text_config_names = ("text_encoder",)
|
||||
if decoder:
|
||||
if return_both:
|
||||
possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
|
||||
elif decoder:
|
||||
possible_text_config_names = decoder_possible_text_config_names
|
||||
else:
|
||||
possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
|
||||
possible_text_config_names = encoder_possible_text_config_names
|
||||
|
||||
valid_text_config_names = []
|
||||
for text_config_name in possible_text_config_names:
|
||||
@@ -1194,12 +1207,27 @@ class PretrainedConfig(PushToHubMixin):
|
||||
if len(valid_text_config_names) > 1:
|
||||
raise ValueError(
|
||||
f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
|
||||
"case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
|
||||
"case, using `get_text_config()` would be ambiguous. Please specify the desired text config directly, "
|
||||
"e.g. `text_config = config.sub_config_name`"
|
||||
)
|
||||
elif len(valid_text_config_names) == 1:
|
||||
config_to_return = getattr(self, valid_text_config_names[0])
|
||||
else:
|
||||
config_to_return = self
|
||||
|
||||
# handle legacy models with flat config structure, when we only want one of the configs
|
||||
if not return_both and len(valid_text_config_names) == 0 and config_to_return.is_encoder_decoder:
|
||||
config_to_return = copy.deepcopy(config_to_return)
|
||||
prefix_to_discard = "encoder" if decoder else "decoder"
|
||||
for key in config_to_return.to_dict():
|
||||
if key.startswith(prefix_to_discard):
|
||||
delattr(config_to_return, key)
|
||||
# old encoder/decoder models may use "encoder_layers"/"decoder_layers" instead of "num_hidden_layers"
|
||||
if decoder and hasattr(config_to_return, "decoder_layers"):
|
||||
config_to_return.num_hidden_layers = config_to_return.decoder_layers
|
||||
elif encoder and hasattr(config_to_return, "encoder_layers"):
|
||||
config_to_return.num_hidden_layers = config_to_return.encoder_layers
|
||||
|
||||
return config_to_return
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1844,12 +1844,19 @@ class GenerationMixin(ContinuousMixin):
|
||||
)
|
||||
|
||||
if need_new_cache:
|
||||
cache_kwargs = {"config": self.config, "max_cache_len": max_cache_len, "offloading": offload_cache}
|
||||
self._cache = StaticCache(**cache_kwargs)
|
||||
self_attention_cache_kwargs = {
|
||||
"config": self.config.get_text_config(decoder=True),
|
||||
"max_cache_len": max_cache_len,
|
||||
"offloading": offload_cache,
|
||||
}
|
||||
self._cache = StaticCache(**self_attention_cache_kwargs)
|
||||
if requires_cross_attention_cache:
|
||||
encoder_kwargs = cache_kwargs.copy()
|
||||
encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
|
||||
self._cache = EncoderDecoderCache(self._cache, StaticCache(**encoder_kwargs))
|
||||
cross_attention_cache_kwargs = {
|
||||
"config": self.config.get_text_config(encoder=True),
|
||||
"max_cache_len": model_kwargs["encoder_outputs"][0].shape[1],
|
||||
"offloading": offload_cache,
|
||||
}
|
||||
self._cache = EncoderDecoderCache(self._cache, StaticCache(**cross_attention_cache_kwargs))
|
||||
else:
|
||||
self._cache.reset()
|
||||
return self._cache
|
||||
|
||||
@@ -87,8 +87,8 @@ class ColQwen2Config(PretrainedConfig):
|
||||
self.initializer_range = initializer_range
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_text_config(self, decoder=False) -> PretrainedConfig:
|
||||
return self.vlm_config.get_text_config(decoder=decoder)
|
||||
def get_text_config(self, *args, **kwargs) -> PretrainedConfig:
|
||||
return self.vlm_config.get_text_config(*args, **kwargs)
|
||||
|
||||
|
||||
__all__ = ["ColQwen2Config"]
|
||||
|
||||
@@ -368,7 +368,7 @@ class DiaConfig(PretrainedConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_text_config(self, decoder=False):
|
||||
def get_text_config(self, *args, **kwargs):
|
||||
"""Defaulting to audio config as it's the decoder in this case which is usually the text backbone"""
|
||||
return self.decoder_config
|
||||
|
||||
|
||||
@@ -1073,7 +1073,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_text_config(self, decoder=False):
|
||||
def get_text_config(self, *args, **kwargs):
|
||||
"""
|
||||
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
|
||||
itself. On specific composite models, it is under a set of valid names.
|
||||
@@ -1085,7 +1085,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
|
||||
# Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
|
||||
# except for Qwen yet. This has to be generalized if more deeply nested configs are
|
||||
# added. NOTE: currently method used only by vLLM
|
||||
return self.thinker_config.get_text_config()
|
||||
return self.thinker_config.get_text_config(*args, **kwargs)
|
||||
|
||||
|
||||
__all__ = ["Qwen2_5OmniConfig", "Qwen2_5OmniThinkerConfig", "Qwen2_5OmniTalkerConfig", "Qwen2_5OmniToken2WavConfig"]
|
||||
|
||||
@@ -1108,7 +1108,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_text_config(self, decoder=False):
|
||||
def get_text_config(self, *args, **kwargs):
|
||||
"""
|
||||
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
|
||||
itself. On specific composite models, it is under a set of valid names.
|
||||
@@ -1120,7 +1120,7 @@ class Qwen2_5OmniConfig(PretrainedConfig):
|
||||
# Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model
|
||||
# except for Qwen yet. This has to be generalized if more deeply nested configs are
|
||||
# added. NOTE: currently method used only by vLLM
|
||||
return self.thinker_config.get_text_config()
|
||||
return self.thinker_config.get_text_config(*args, **kwargs)
|
||||
|
||||
|
||||
class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
|
||||
|
||||
@@ -323,9 +323,8 @@ class T5GemmaConfig(PretrainedConfig):
|
||||
setattr(self.decoder, key, value)
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def get_text_config(self, decoder=False):
|
||||
def get_text_config(self, *args, **kwargs):
|
||||
# Always return self, regardless of the decoder option.
|
||||
del decoder
|
||||
return self
|
||||
|
||||
|
||||
|
||||
@@ -213,9 +213,8 @@ class T5GemmaConfig(PretrainedConfig):
|
||||
setattr(self.decoder, key, value)
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def get_text_config(self, decoder=False):
|
||||
def get_text_config(self, *args, **kwargs):
|
||||
# Always return self, regardless of the decoder option.
|
||||
del decoder
|
||||
return self
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from pathlib import Path
|
||||
from huggingface_hub import HfFolder
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers import AutoConfig, BertConfig, GPT2Config
|
||||
from transformers import AutoConfig, BertConfig, Florence2Config, GPT2Config
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test
|
||||
|
||||
@@ -300,3 +300,35 @@ class ConfigTestUtils(unittest.TestCase):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
PretrainedConfig.from_pretrained("bert-base-uncased")
|
||||
|
||||
def test_get_text_config(self):
|
||||
"""Tests the `get_text_config` method."""
|
||||
# 1. model with only text input -> returns the original config instance
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
|
||||
self.assertEqual(config.get_text_config(), config)
|
||||
self.assertEqual(config.get_text_config(decoder=True), config)
|
||||
|
||||
# 2. composite model (VLM) -> returns the text component
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlavaForConditionalGeneration")
|
||||
self.assertEqual(config.get_text_config(), config.text_config)
|
||||
self.assertEqual(config.get_text_config(decoder=True), config.text_config)
|
||||
|
||||
# 3. ! corner case! : composite model whose sub-config is an old composite model (should behave as above)
|
||||
config = Florence2Config()
|
||||
self.assertEqual(config.get_text_config(), config.text_config)
|
||||
self.assertEqual(config.get_text_config(decoder=True), config.text_config)
|
||||
|
||||
# 4. old composite model -> may remove components based on the `decoder` or `encoder` argument
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
self.assertEqual(config.get_text_config(), config)
|
||||
# both encoder_layers and decoder_layers exist
|
||||
self.assertTrue(getattr(config, "encoder_layers", None) is not None)
|
||||
self.assertTrue(getattr(config, "decoder_layers", None) is not None)
|
||||
decoder_config = config.get_text_config(decoder=True)
|
||||
self.assertNotEqual(decoder_config, config)
|
||||
self.assertEqual(decoder_config.num_hidden_layers, config.decoder_layers)
|
||||
self.assertTrue(getattr(decoder_config, "encoder_layers", None) is None) # encoder_layers is removed
|
||||
encoder_config = config.get_text_config(encoder=True)
|
||||
self.assertNotEqual(encoder_config, config)
|
||||
self.assertEqual(encoder_config.num_hidden_layers, config.encoder_layers)
|
||||
self.assertTrue(getattr(encoder_config, "decoder_layers", None) is None) # decoder_layers is removed
|
||||
|
||||
Reference in New Issue
Block a user