[Phi-3] Bug on stale kv cache (#33129)
* fix long seq bug * fixed format * fixed fn copy inconsistency * fix long seq bug * fixed format * fixed fn copy inconsistency * Addressed comments * added a unit test * fixed cache position * Added a warning msg to the forward fn * fixed test case
This commit is contained in:
@@ -257,7 +257,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, x, position_ids, seq_len=None):
|
def forward(self, x, position_ids, seq_len=None):
|
||||||
seq_len = torch.max(position_ids) + 1
|
seq_len = seq_len or torch.max(position_ids) + 1
|
||||||
if seq_len > self.original_max_position_embeddings:
|
if seq_len > self.original_max_position_embeddings:
|
||||||
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
|
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
|
||||||
else:
|
else:
|
||||||
@@ -1239,6 +1239,15 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
|
|||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
|
'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
|
||||||
```"""
|
```"""
|
||||||
|
if (
|
||||||
|
use_cache
|
||||||
|
and self.config.rope_scaling
|
||||||
|
and cache_position is not None
|
||||||
|
and cache_position[0] == self.config.original_max_position_embeddings
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
|
||||||
|
)
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -1295,7 +1304,6 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
|
|||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -1308,6 +1316,17 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
|
|||||||
num_logits_to_keep=None,
|
num_logits_to_keep=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
# When the first time input length reached long and short factor switching point, enforce re-compute cache
|
||||||
|
# It will cause downside of slower at this single token position, however, better than current failure.
|
||||||
|
if (
|
||||||
|
past_key_values
|
||||||
|
and self.config.rope_scaling
|
||||||
|
and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
|
||||||
|
):
|
||||||
|
past_length = cache_position[0]
|
||||||
|
if past_length <= self.config.original_max_position_embeddings:
|
||||||
|
past_key_values = None
|
||||||
|
|
||||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||||
|
|||||||
@@ -442,6 +442,47 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||||
|
|
||||||
|
@parameterized.expand([("longrope",)])
|
||||||
|
def test_model_rope_scaling_short_long_factor(self, scaling_type):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
n_factors = config.hidden_size // config.num_key_value_heads // 2
|
||||||
|
config.rope_scaling = {
|
||||||
|
"type": scaling_type,
|
||||||
|
"short_factor": [3.0 for _ in range(n_factors)],
|
||||||
|
"long_factor": [5.0 for _ in range(n_factors)],
|
||||||
|
}
|
||||||
|
input_tensor = ids_tensor([1, 4090], config.vocab_size)
|
||||||
|
model = Phi3ForCausalLM(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
generation_args_short = {
|
||||||
|
"max_length": config.original_max_position_embeddings,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"use_cache": True,
|
||||||
|
"do_sample": False,
|
||||||
|
"return_dict_in_generate": True,
|
||||||
|
}
|
||||||
|
output_with_short_factor = model.generate(input_tensor, **generation_args_short)
|
||||||
|
keys_with_short_factor = output_with_short_factor.past_key_values[0][0]
|
||||||
|
generation_args_long = {
|
||||||
|
"max_length": config.original_max_position_embeddings + 5,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"use_cache": True,
|
||||||
|
"do_sample": False,
|
||||||
|
"return_dict_in_generate": True,
|
||||||
|
"output_logits": True,
|
||||||
|
}
|
||||||
|
output_with_long_factor = model.generate(input_tensor, **generation_args_long)
|
||||||
|
keys_with_long_factor = output_with_long_factor.past_key_values[0][0]
|
||||||
|
last_token_logits = output_with_long_factor.logits[-1][-1]
|
||||||
|
regenerated_last_token_logits = model(output_with_long_factor.sequences[:, :-1]).logits[0][-1]
|
||||||
|
keys_with_long_factor = keys_with_long_factor[:, :, : config.original_max_position_embeddings - 1, :]
|
||||||
|
|
||||||
|
# KV cache is re-computed after reaching the (`config.original_max_position_embeddings`+1)th token position
|
||||||
|
self.assertFalse(torch.allclose(keys_with_short_factor, keys_with_long_factor, atol=1e-2, rtol=1e-2))
|
||||||
|
# Last token generated using long factor
|
||||||
|
self.assertTrue(torch.allclose(last_token_logits, regenerated_last_token_logits, atol=1e-2, rtol=1e-2))
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user