Iterative generation using Input embeds and past_key_values (#35890)
* Iterative generation using input embeds
* ruff fix
* Added Testcase
* Updated comment
* ♻️ Refactored testcase
* Skip test for these models
* Continue generation using input embeds and cache
* Skip generate_continue_from_embeds test
* Refactor `prepare_input_for_generation` func
* Continue generation using input embeds and cache
* Modular changes fix
* Overwrite 'prepare_inputs_for_generation' function
This commit is contained in:
@@ -334,6 +334,10 @@ class ClvpDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
loss = model(**inputs).loss
|
||||
loss.backward()
|
||||
|
||||
@unittest.skip(reason="Clvp `prepare_inputs_for_generation` function doesn't have cache position.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
|
||||
class ClvpModelForConditionalGenerationTester:
|
||||
def __init__(self, parent, is_training=False):
|
||||
|
||||
@@ -131,6 +131,10 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Cohere2 has HybridCache and doesn't support progressive generation using input embeds.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# overwrite because HybridCache has fixed length for key/values
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
|
||||
@@ -325,6 +325,10 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
def test_model_parallelism(self):
|
||||
super().test_model_parallelism()
|
||||
|
||||
@unittest.skip(reason="Fuyu `prepare_inputs_for_generation` function doesn't have cache position.")
|
||||
def test_generate_continue_from_inputs_embeds():
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
|
||||
@@ -146,6 +146,10 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# overwrite because HybridCache has fixed length for key/values
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
|
||||
@@ -450,6 +450,10 @@ class GPTBigCodeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BigCodeGPT has a non-standard KV cache format and breaks this test.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_gpt_bigcode_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs)
|
||||
|
||||
@@ -755,6 +755,65 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
||||
)
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
"""Overwrite for IDEFICS: Ensure image attention mask is processed while continuing from `inputs_embeds`."""
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
print(inputs)
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
|
||||
model.generation_config.forced_eos_token_id = None
|
||||
model.generation_config.use_cache = True
|
||||
|
||||
input_ids = inputs.pop("input_ids")
|
||||
input_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
generation_kwargs = {
|
||||
"return_dict_in_generate": True,
|
||||
"do_sample": False,
|
||||
}
|
||||
|
||||
inputs["inputs_embeds"] = input_embeds
|
||||
|
||||
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
|
||||
outputs = model.generate(**inputs, max_new_tokens=4, **generation_kwargs)
|
||||
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
|
||||
# inputs may need to be tweaked across `generate` calls (like the attention mask).
|
||||
initial_output = model.generate(**inputs, max_new_tokens=3, **generation_kwargs)
|
||||
inputs["past_key_values"] = initial_output.past_key_values
|
||||
|
||||
new_attention_len = input_ids.shape[1] + initial_output.sequences.shape[-1]
|
||||
continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1)
|
||||
inputs["inputs_embeds"] = continued_embeds
|
||||
|
||||
if "attention_mask" in inputs:
|
||||
inputs["attention_mask"] = torch.nn.functional.pad(
|
||||
inputs["attention_mask"],
|
||||
(0, new_attention_len - inputs["attention_mask"].shape[1]),
|
||||
mode="constant",
|
||||
value=1,
|
||||
)
|
||||
if "image_attention_mask" in inputs:
|
||||
inputs["image_attention_mask"] = inputs["image_attention_mask"][..., -1:, :]
|
||||
|
||||
cached_output = model.generate(**inputs, max_new_tokens=1, **generation_kwargs)
|
||||
|
||||
# Verify that the combined outputs match the full generation.
|
||||
combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1)
|
||||
self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist())
|
||||
for layer_idx in range(len(cached_output.past_key_values)):
|
||||
for kv_idx in range(len(cached_output.past_key_values[layer_idx])):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
outputs.past_key_values[layer_idx][kv_idx],
|
||||
cached_output.past_key_values[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
|
||||
@@ -358,6 +358,10 @@ class MoshiDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_disk_offload_safetensors(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Test becomes too complex with Moshi requiring multiple input modalities.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@is_flaky(max_attempts=5, description="flaky on some models.")
|
||||
def test_save_load(self):
|
||||
super().test_save_load()
|
||||
@@ -824,6 +828,7 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
output_ids_generate = model.generate(
|
||||
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
|
||||
)
|
||||
print(output_ids_generate)
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
@unittest.skip(reason="The audio encoder has no gradients.")
|
||||
@@ -919,6 +924,10 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
def test_disk_offload_safetensors(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Test becomes too complex with Moshi requiring multiple modalities")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@is_flaky(max_attempts=5, description="flaky on some models.")
|
||||
def test_save_load(self):
|
||||
super().test_save_load()
|
||||
|
||||
@@ -333,6 +333,10 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
"""
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Zamba2 has hybrid cache.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user