From a464afbe2ac6b1d866d6850fa509f9e8dd244a92 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 25 Nov 2024 23:59:38 +0800 Subject: [PATCH] fix static cache data type miss-match (#34799) * fix gptj data type missmatch Signed-off-by: jiqing-feng * add low precision static cache tests Signed-off-by: jiqing-feng * fix format Signed-off-by: jiqing-feng * fix low-precision static cache tests * fix format Signed-off-by: jiqing-feng * avoid config change Signed-off-by: jiqing-feng * change data type convert in cache copy Signed-off-by: jiqing-feng * fix comment Signed-off-by: jiqing-feng * cast key value after k v out Signed-off-by: jiqing-feng --------- Signed-off-by: jiqing-feng --- src/transformers/cache_utils.py | 2 ++ tests/generation/test_utils.py | 59 ++++++++++++++++++--------------- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 490280ce81..f3f0bd6fe5 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1217,6 +1217,8 @@ class StaticCache(Cache): k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) if cache_position is None: k_out.copy_(key_states) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6c9a4801b6..0605ea7939 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1901,36 +1901,41 @@ class GenerationTesterMixin: seq_length = main_input.shape[-1] max_new_tokens = 20 - model = model_class(config).to(torch_device).eval() - generation_kwargs = { - "max_new_tokens": max_new_tokens, - "return_dict_in_generate": True, # Required to return `past_key_values` - "output_scores": True, - "use_cache": True, - } + for dtype in (torch.float32, torch.float16): + model = model_class(config).to(torch_device).to(dtype).eval() + generation_kwargs = { + "max_new_tokens": max_new_tokens, + "return_dict_in_generate": True, # Required to return `past_key_values` + "output_scores": True, + "use_cache": True, + } - static_cache_generation = model.generate(**generation_kwargs, **inputs_dict, cache_implementation="static") + static_cache_generation = model.generate( + **generation_kwargs, **inputs_dict, cache_implementation="static" + ) - # Check 1: The cache shapes must match the expected shapes - max_cache_len = seq_length + max_new_tokens - config = config.text_config if hasattr(config, "text_config") else config - head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) - num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - num_hidden_layers = config.num_hidden_layers - cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) - self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache)) - self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers) - self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape) + # Check 1: The cache shapes must match the expected shapes + max_cache_len = seq_length + max_new_tokens + text_config = config.text_config if hasattr(config, "text_config") else config + head_dim = ( + text_config.head_dim + if hasattr(text_config, "head_dim") + else text_config.hidden_size // text_config.num_attention_heads + ) + num_key_value_heads = ( + text_config.num_attention_heads + if getattr(text_config, "num_key_value_heads", None) is None + else text_config.num_key_value_heads + ) + num_hidden_layers = text_config.num_hidden_layers + cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) + self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache)) + self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers) + self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape) - # Check 2: The outputs must be similar to the case with dynamic cache - dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict) - self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation) + # Check 2: The outputs must be similar to the case with dynamic cache + dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict) + self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation) @require_optimum_quanto @pytest.mark.generate