Fix static cache data type mismatch (#34799)
* fix gptj data type mismatch
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* add low precision static cache tests
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix format
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix low-precision static cache tests
* fix format
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* avoid config change
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* change data type convert in cache copy
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix comment
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* cast key value after k v out
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
@@ -1217,6 +1217,8 @@ class StaticCache(Cache):
|
|||||||
|
|
||||||
k_out = self.key_cache[layer_idx]
|
k_out = self.key_cache[layer_idx]
|
||||||
v_out = self.value_cache[layer_idx]
|
v_out = self.value_cache[layer_idx]
|
||||||
|
key_states = key_states.to(k_out.dtype)
|
||||||
|
value_states = value_states.to(v_out.dtype)
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
k_out.copy_(key_states)
|
k_out.copy_(key_states)
|
||||||
|
|||||||
@@ -1901,36 +1901,41 @@ class GenerationTesterMixin:
|
|||||||
seq_length = main_input.shape[-1]
|
seq_length = main_input.shape[-1]
|
||||||
max_new_tokens = 20
|
max_new_tokens = 20
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
for dtype in (torch.float32, torch.float16):
|
||||||
generation_kwargs = {
|
model = model_class(config).to(torch_device).to(dtype).eval()
|
||||||
"max_new_tokens": max_new_tokens,
|
generation_kwargs = {
|
||||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
"max_new_tokens": max_new_tokens,
|
||||||
"output_scores": True,
|
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||||
"use_cache": True,
|
"output_scores": True,
|
||||||
}
|
"use_cache": True,
|
||||||
|
}
|
||||||
|
|
||||||
static_cache_generation = model.generate(**generation_kwargs, **inputs_dict, cache_implementation="static")
|
static_cache_generation = model.generate(
|
||||||
|
**generation_kwargs, **inputs_dict, cache_implementation="static"
|
||||||
|
)
|
||||||
|
|
||||||
# Check 1: The cache shapes must match the expected shapes
|
# Check 1: The cache shapes must match the expected shapes
|
||||||
max_cache_len = seq_length + max_new_tokens
|
max_cache_len = seq_length + max_new_tokens
|
||||||
config = config.text_config if hasattr(config, "text_config") else config
|
text_config = config.text_config if hasattr(config, "text_config") else config
|
||||||
head_dim = (
|
head_dim = (
|
||||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
text_config.head_dim
|
||||||
)
|
if hasattr(text_config, "head_dim")
|
||||||
num_key_value_heads = (
|
else text_config.hidden_size // text_config.num_attention_heads
|
||||||
config.num_attention_heads
|
)
|
||||||
if getattr(config, "num_key_value_heads", None) is None
|
num_key_value_heads = (
|
||||||
else config.num_key_value_heads
|
text_config.num_attention_heads
|
||||||
)
|
if getattr(text_config, "num_key_value_heads", None) is None
|
||||||
num_hidden_layers = config.num_hidden_layers
|
else text_config.num_key_value_heads
|
||||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
)
|
||||||
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
|
num_hidden_layers = text_config.num_hidden_layers
|
||||||
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
|
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||||
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
|
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
|
||||||
|
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
|
||||||
|
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
|
||||||
|
|
||||||
# Check 2: The outputs must be similar to the case with dynamic cache
|
# Check 2: The outputs must be similar to the case with dynamic cache
|
||||||
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
|
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
|
||||||
self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
|
self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
|
||||||
|
|
||||||
@require_optimum_quanto
|
@require_optimum_quanto
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
|
|||||||
Reference in New Issue
Block a user