[Phi-3] Bug on stale kv cache (#33129)
* fix long seq bug
* fixed format
* fixed fn copy inconsistency
* fix long seq bug
* fixed format
* fixed fn copy inconsistency
* Addressed comments
* added a unit test
* fixed cache position
* Added a warning msg to the forward fn
* fixed test case
This commit is contained in:
@@ -442,6 +442,47 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||
|
||||
@parameterized.expand([("longrope",)])
def test_model_rope_scaling_short_long_factor(self, scaling_type):
    """Compare cached keys and final logits for generations below and above the original context limit.

    Two greedy generations are run from the same 4090-token prompt with a rope
    scaling config that has distinct `short_factor` and `long_factor` values:
    one capped at `config.original_max_position_embeddings` and one that runs
    five tokens past it. The test asserts that

    * the first-layer cached keys of the two runs differ on the shared prefix, and
    * the logits of the last generated step of the longer run match a plain
      forward pass over the same sequence (minus its final token).
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    # One scaling factor per rotary frequency; the two lists use different
    # values so that the short and long regimes produce distinguishable keys.
    factor_count = config.hidden_size // config.num_key_value_heads // 2
    config.rope_scaling = {
        "type": scaling_type,
        "short_factor": [3.0] * factor_count,
        "long_factor": [5.0] * factor_count,
    }

    # NOTE(review): 4090 is presumably chosen to sit just below
    # `config.original_max_position_embeddings` -- confirm against the tester config.
    prompt_ids = ids_tensor([1, 4090], config.vocab_size)

    model = Phi3ForCausalLM(config)
    model.to(torch_device)
    model.eval()

    # Settings common to both generations (greedy decoding with a KV cache).
    shared_generation_args = {
        "temperature": 0.0,
        "use_cache": True,
        "do_sample": False,
        "return_dict_in_generate": True,
    }

    # Run 1: stop at the original maximum length.
    short_run = model.generate(
        prompt_ids,
        max_length=config.original_max_position_embeddings,
        **shared_generation_args,
    )
    short_run_keys = short_run.past_key_values[0][0]

    # Run 2: continue five tokens beyond the original maximum length and
    # also collect the per-step logits.
    long_run = model.generate(
        prompt_ids,
        max_length=config.original_max_position_embeddings + 5,
        output_logits=True,
        **shared_generation_args,
    )
    long_run_keys = long_run.past_key_values[0][0]
    final_step_logits = long_run.logits[-1][-1]

    # Reference logits: a fresh forward pass over everything except the last
    # generated token, taking the prediction at the final position.
    reference_logits = model(long_run.sequences[:, :-1]).logits[0][-1]

    # Trim the long run's keys to the span that the short run also cached.
    long_run_keys = long_run_keys[:, :, : config.original_max_position_embeddings - 1, :]

    # KV cache is re-computed after reaching the (`config.original_max_position_embeddings`+1)th token position
    keys_match = torch.allclose(short_run_keys, long_run_keys, atol=1e-2, rtol=1e-2)
    self.assertFalse(keys_match)

    # Last token generated using long factor
    logits_match = torch.allclose(final_step_logits, reference_logits, atol=1e-2, rtol=1e-2)
    self.assertTrue(logits_match)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user