fix static cache data type mismatch (#34799)

* fix gptj data type mismatch

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add low precision static cache tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix low-precision static cache tests

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* avoid config change

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* change data type conversion in cache copy

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* cast key/value states after k_out/v_out are fetched

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
jiqing-feng
2024-11-25 23:59:38 +08:00
committed by GitHub
parent b13916c09d
commit a464afbe2a
2 changed files with 34 additions and 27 deletions

View File

@@ -1217,6 +1217,8 @@ class StaticCache(Cache):
k_out = self.key_cache[layer_idx] k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx] v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
if cache_position is None: if cache_position is None:
k_out.copy_(key_states) k_out.copy_(key_states)

View File

@@ -1901,36 +1901,41 @@ class GenerationTesterMixin:
seq_length = main_input.shape[-1] seq_length = main_input.shape[-1]
max_new_tokens = 20 max_new_tokens = 20
model = model_class(config).to(torch_device).eval() for dtype in (torch.float32, torch.float16):
generation_kwargs = { model = model_class(config).to(torch_device).to(dtype).eval()
"max_new_tokens": max_new_tokens, generation_kwargs = {
"return_dict_in_generate": True, # Required to return `past_key_values` "max_new_tokens": max_new_tokens,
"output_scores": True, "return_dict_in_generate": True, # Required to return `past_key_values`
"use_cache": True, "output_scores": True,
} "use_cache": True,
}
static_cache_generation = model.generate(**generation_kwargs, **inputs_dict, cache_implementation="static") static_cache_generation = model.generate(
**generation_kwargs, **inputs_dict, cache_implementation="static"
)
# Check 1: The cache shapes must match the expected shapes # Check 1: The cache shapes must match the expected shapes
max_cache_len = seq_length + max_new_tokens max_cache_len = seq_length + max_new_tokens
config = config.text_config if hasattr(config, "text_config") else config text_config = config.text_config if hasattr(config, "text_config") else config
head_dim = ( head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads text_config.head_dim
) if hasattr(text_config, "head_dim")
num_key_value_heads = ( else text_config.hidden_size // text_config.num_attention_heads
config.num_attention_heads )
if getattr(config, "num_key_value_heads", None) is None num_key_value_heads = (
else config.num_key_value_heads text_config.num_attention_heads
) if getattr(text_config, "num_key_value_heads", None) is None
num_hidden_layers = config.num_hidden_layers else text_config.num_key_value_heads
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim) )
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache)) num_hidden_layers = text_config.num_hidden_layers
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers) cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape) self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
# Check 2: The outputs must be similar to the case with dynamic cache # Check 2: The outputs must be similar to the case with dynamic cache
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict) dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation) self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
@require_optimum_quanto @require_optimum_quanto
@pytest.mark.generate @pytest.mark.generate