[Phi-3] Bug on stale kv cache (#33129)
* fix long seq bug
* fixed format
* fixed fn copy inconsistency
* fix long seq bug
* fixed format
* fixed fn copy inconsistency
* Addressed comments
* added a unit test
* fixed cache position
* Added a warning msg to the forward fn
* fixed test case
This commit is contained in:
@@ -257,7 +257,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids, seq_len=None):
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
seq_len = seq_len or torch.max(position_ids) + 1
|
||||
if seq_len > self.original_max_position_embeddings:
|
||||
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
|
||||
else:
|
||||
@@ -1239,6 +1239,15 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
|
||||
```"""
|
||||
if (
|
||||
use_cache
|
||||
and self.config.rope_scaling
|
||||
and cache_position is not None
|
||||
and cache_position[0] == self.config.original_max_position_embeddings
|
||||
):
|
||||
logger.warning(
|
||||
f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
|
||||
)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -1295,7 +1304,6 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -1308,6 +1316,17 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
|
||||
num_logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# The first time the input length reaches the switching point between the short and long factors, force the cache to be recomputed.
|
||||
# This makes generation slower at that single token position, but that is preferable to the failure that occurs otherwise.
|
||||
if (
|
||||
past_key_values
|
||||
and self.config.rope_scaling
|
||||
and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
|
||||
):
|
||||
past_length = cache_position[0]
|
||||
if past_length <= self.config.original_max_position_embeddings:
|
||||
past_key_values = None
|
||||
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
|
||||
@@ -442,6 +442,47 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||
|
||||
@parameterized.expand([("longrope",)])
def test_model_rope_scaling_short_long_factor(self, scaling_type):
    """Compare a generation that stops at the original context length with one that runs past it.

    Two `generate` calls are made on the same prompt with sampling disabled:
    one capped at `config.original_max_position_embeddings` tokens and one
    capped 5 tokens beyond that. The test then checks that

    * the layer-0 keys cached by the two runs differ, and
    * the logits of the last generated token of the longer run match the
      logits obtained from a plain forward pass over the same sequence.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    # One scaling factor per pair of head dimensions.
    factor_count = config.hidden_size // config.num_key_value_heads // 2
    config.rope_scaling = {
        "type": scaling_type,
        "short_factor": [3.0] * factor_count,
        "long_factor": [5.0] * factor_count,
    }

    prompt_ids = ids_tensor([1, 4090], config.vocab_size)
    model = Phi3ForCausalLM(config)
    model.to(torch_device)
    model.eval()

    # Arguments common to both generation runs.
    shared_generation_kwargs = {
        "temperature": 0.0,
        "use_cache": True,
        "do_sample": False,
        "return_dict_in_generate": True,
    }

    # Run 1: stop exactly at the original maximum position.
    short_run = model.generate(
        prompt_ids,
        max_length=config.original_max_position_embeddings,
        **shared_generation_kwargs,
    )
    short_run_keys = short_run.past_key_values[0][0]

    # Run 2: continue 5 tokens past the original maximum position and keep the logits.
    long_run = model.generate(
        prompt_ids,
        max_length=config.original_max_position_embeddings + 5,
        output_logits=True,
        **shared_generation_kwargs,
    )
    long_run_keys = long_run.past_key_values[0][0]
    final_step_logits = long_run.logits[-1][-1]

    # Reference logits: forward pass over everything except the last generated token.
    reference_logits = model(long_run.sequences[:, :-1]).logits[0][-1]

    # Keep only the first `original_max_position_embeddings - 1` entries along the third axis
    # so the slice can be compared against the keys of the short run.
    long_run_keys = long_run_keys[:, :, : config.original_max_position_embeddings - 1, :]

    # KV cache is re-computed after reaching the (`config.original_max_position_embeddings`+1)th token position
    keys_match = torch.allclose(short_run_keys, long_run_keys, atol=1e-2, rtol=1e-2)
    self.assertFalse(keys_match)

    # Last token generated using long factor
    logits_match = torch.allclose(final_step_logits, reference_logits, atol=1e-2, rtol=1e-2)
    self.assertTrue(logits_match)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user