From 9470d6532436e9db2951a196effd6f8841befb76 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 20 Nov 2024 07:46:35 +0100 Subject: [PATCH] Fix low memory beam search (#34746) * fix * higher max positions in tests --- src/transformers/cache_utils.py | 12 +++++++++--- tests/generation/test_utils.py | 1 - tests/models/blip_2/test_modeling_blip_2.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 3b491b0460..490280ce81 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -528,7 +528,7 @@ class DynamicCache(Cache): cache = cls() for idx in range(len(splits[0])): key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] - value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] + value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []] if key_cache != []: layer_keys = torch.cat(key_cache, dim=0) layer_values = torch.cat(value_cache, dim=0) @@ -1523,7 +1523,10 @@ class EncoderDecoderCache(Cache): self.check_dynamic_cache(self.crop.__name__) self.self_attention_cache.crop(maximum_length) - def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]": + @deprecate_kwarg("num_hidden_layers", version="4.47.0") + def batch_split( + self, full_batch_size: int, split_size: int, num_hidden_layers: int = None + ) -> "List[EncoderDecoderCache]": """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by `_split_model_inputs()` in `generation.utils`""" self.check_dynamic_cache(self.batch_split.__name__) @@ -1536,7 +1539,10 @@ class EncoderDecoderCache(Cache): return out @classmethod - def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache": + @deprecate_kwarg("num_hidden_layers", version="4.47.0") + def from_batch_splits( + cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int = None + ) -> "EncoderDecoderCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" self_attention_cache = DynamicCache() diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b1d0042c65..76dc23ed9b 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1046,7 +1046,6 @@ class GenerationTesterMixin: self.assertListEqual(low_output.tolist(), high_output.tolist()) @pytest.mark.generate - @unittest.skip("Started to break with https://github.com/huggingface/transformers/pull/33703") def test_beam_search_low_memory(self): # Check that choosing 'low_memory' does not change the model output for model_class in self.all_generative_model_classes: diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 0943661b96..a141ef40be 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -330,7 +330,7 @@ class Blip2TextModelDecoderOnlyTester: hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, - max_position_embeddings=20, + max_position_embeddings=512, eos_token_id=2, pad_token_id=1, bos_token_id=0,