From dad513e0c2a93c6f261be73dd0f648acb8a25c2b Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 17 Feb 2025 13:55:03 +0000 Subject: [PATCH] [generate] remove cache v4.47 deprecations (#36212) --- src/transformers/cache_utils.py | 26 +++++-------------- src/transformers/generation/utils.py | 13 ++++------ tests/models/phimoe/test_modeling_phimoe.py | 2 ++ .../qwen2_5_vl/test_modeling_qwen2_5_vl.py | 5 ++++ 4 files changed, 18 insertions(+), 28 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 427e1d4e3a..07d4654c35 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -363,8 +363,7 @@ class DynamicCache(Cache): ``` """ - @deprecate_kwarg("num_hidden_layers", version="4.47.0") - def __init__(self, num_hidden_layers: Optional[int] = None) -> None: + def __init__(self) -> None: super().__init__() self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen self.key_cache: List[torch.Tensor] = [] @@ -466,10 +465,7 @@ class DynamicCache(Cache): return legacy_cache @classmethod - @deprecate_kwarg("num_hidden_layers", version="4.47.0") - def from_legacy_cache( - cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None - ) -> "DynamicCache": + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for backward compatibility.""" cache = cls() @@ -495,10 +491,7 @@ class DynamicCache(Cache): self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - @deprecate_kwarg("num_hidden_layers", version="4.47.0") - def batch_split( - self, full_batch_size: int, split_size: int, num_hidden_layers: int = None - ) -> List["DynamicCache"]: + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by `_split_model_inputs()` in `generation.utils`""" out = [] @@ -511,8 +504,7 @@ class DynamicCache(Cache): return out @classmethod - @deprecate_kwarg("num_hidden_layers", version="4.47.0") - def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache": + def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" cache = cls() @@ -1527,10 +1519,7 @@ class EncoderDecoderCache(Cache): self.check_dynamic_cache(self.crop.__name__) self.self_attention_cache.crop(maximum_length) - @deprecate_kwarg("num_hidden_layers", version="4.47.0") - def batch_split( - self, full_batch_size: int, split_size: int, num_hidden_layers: int = None - ) -> "List[EncoderDecoderCache]": + def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]": """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by `_split_model_inputs()` in `generation.utils`""" self.check_dynamic_cache(self.batch_split.__name__) @@ -1543,10 +1532,7 @@ class EncoderDecoderCache(Cache): return out @classmethod - @deprecate_kwarg("num_hidden_layers", version="4.47.0") - def from_batch_splits( - cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int = None - ) -> "EncoderDecoderCache": + def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" self_attention_cache = DynamicCache() diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index db8bbe50e5..9760b37dea 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4520,7 +4520,7 @@ def _ranking_fast( return selected_idx -def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int = None): +def _split(data, full_batch_size: int, split_size: int = None): """ Takes care of three cases: 1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim @@ -4538,7 +4538,7 @@ def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int = elif isinstance(data, DynamicCache) or ( isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache) ): - return data.batch_split(full_batch_size, split_size, num_hidden_layers) + return data.batch_split(full_batch_size, split_size) elif isinstance(data, tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0], tuple): @@ -4591,11 +4591,9 @@ def _split_model_inputs( keys_to_ignore = ["cache_position", "encoder_outputs", "logits_to_keep"] non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] - num_hidden_layers = config.get_text_config().num_hidden_layers - # we split the tensors and tuples of tensors data_split_list = [ - {k: _split(model_input[k], full_batch_size, num_hidden_layers, split_size)[i] for k in non_bool_keys} + {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys} for i in range(full_batch_size // split_size) ] # bool values are the same and replicated for each split @@ -4632,7 +4630,6 @@ def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConf # Infer the class from the first object in the list model_output_cls = type(model_outputs[0]) - num_hidden_layers = config.get_text_config().num_hidden_layers # Ensure all objects are of the same type if not all(isinstance(obj, model_output_cls) for obj in model_outputs): @@ -4649,9 +4646,9 @@ def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConf return torch.cat(data, dim=0) # New cache format elif isinstance(data[0], DynamicCache): - return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) + return DynamicCache.from_batch_splits(data) elif isinstance(data[0], EncoderDecoderCache): - return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) + return EncoderDecoderCache.from_batch_splits(data) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): diff --git a/tests/models/phimoe/test_modeling_phimoe.py b/tests/models/phimoe/test_modeling_phimoe.py index 40448a0a85..b3dc1eba68 100644 --- a/tests/models/phimoe/test_modeling_phimoe.py +++ b/tests/models/phimoe/test_modeling_phimoe.py @@ -22,6 +22,7 @@ from parameterized import parameterized from transformers import PhimoeConfig, StaticCache, is_torch_available, set_seed from transformers.testing_utils import ( + is_flaky, require_torch, slow, torch_device, @@ -449,6 +450,7 @@ class PhimoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) @parameterized.expand([("longrope",)]) + @is_flaky() # TODO (joao): unify rope tests in the mixin def test_model_rope_scaling_short_long_factor(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() n_factors = config.hidden_size // config.num_key_value_heads // 2 diff --git a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py index cfcfd3a620..dcb0816a0d 100644 --- a/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py +++ b/tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py @@ -27,6 +27,7 @@ from transformers import ( is_vision_available, ) from transformers.testing_utils import ( + is_flaky, require_flash_attn, require_torch, require_torch_gpu, @@ -347,6 +348,10 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test def test_generate_compile_fullgraph(self): pass + @is_flaky() # TODO (joao/raushan): Investigate why this test is flaky on this model + def test_prompt_lookup_decoding_matches_greedy_search(self): + super().test_prompt_lookup_decoding_matches_greedy_search() + @require_torch class Qwen2_5_VLIntegrationTest(unittest.TestCase):