Iterative generation using Input embeds and past_key_values (#35890)
* Iterative generation using input embeds
* ruff fix
* Added Testcase
* Updated comment
* ♻️ Refactored testcase
* Skip test for these models
* Continue generation using input embeds and cache
* Skip generate_continue_from_embeds test
* Refactor `prepare_input_for_generation` func
* Continue generation using input embeds and cache
* Modular changes fix
* Overwrite 'prepare_inputs_for_generation' function
This commit is contained in:
@@ -1857,6 +1857,83 @@ class GenerationTesterMixin:
|
||||
)
|
||||
)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
"""Tests that we can continue generation from `inputs_embeds` and past key values returned from a previous `generate` call."""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
|
||||
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
|
||||
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
if "token_type_ids" in inputs_dict:
|
||||
del inputs_dict["token_type_ids"]
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
self.skipTest(reason="This model is encoder-decoder")
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
|
||||
self.skipTest(reason="This model does not support `inputs_embeds` in generation")
|
||||
|
||||
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
|
||||
outputs = model(**inputs_dict)
|
||||
if "past_key_values" not in outputs:
|
||||
self.skipTest(reason="This model doesn't return `past_key_values`")
|
||||
|
||||
pixel_values_is_mutually_exclusive = any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3"]
|
||||
)
|
||||
if pixel_values_is_mutually_exclusive:
|
||||
inputs_dict.pop("pixel_values", None)
|
||||
inputs_dict.pop("pixel_values_videos", None)
|
||||
inputs_dict.pop("pixel_values_images", None)
|
||||
|
||||
input_ids = inputs_dict.pop("input_ids")
|
||||
|
||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
|
||||
model.generation_config.forced_eos_token_id = None
|
||||
model.config.is_decoder = True
|
||||
model.generation_config.use_cache = True
|
||||
|
||||
generation_kwargs = {
|
||||
"return_dict_in_generate": True,
|
||||
"do_sample": False,
|
||||
}
|
||||
|
||||
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values.
|
||||
input_embeds = model.get_input_embeddings()(input_ids)
|
||||
outputs = model.generate(inputs_embeds=input_embeds, max_new_tokens=4, **generation_kwargs)
|
||||
|
||||
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens)
|
||||
initial_output = model.generate(inputs_embeds=input_embeds, max_new_tokens=3, **generation_kwargs)
|
||||
continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1)
|
||||
cached_output = model.generate(
|
||||
inputs_embeds=continued_embeds,
|
||||
max_new_tokens=1,
|
||||
past_key_values=initial_output.past_key_values,
|
||||
**generation_kwargs,
|
||||
)
|
||||
|
||||
# Combine the (3 + 1) generated tokens and verify it matches with full generation.
|
||||
combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1)
|
||||
self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist())
|
||||
# The two sets of past kv should be equal to each other
|
||||
for layer_idx in range(len(cached_output.past_key_values)):
|
||||
for kv_idx in range(len(cached_output.past_key_values[layer_idx])):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
outputs.past_key_values[layer_idx][kv_idx],
|
||||
cached_output.past_key_values[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
|
||||
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
|
||||
@require_torch_gpu
|
||||
@pytest.mark.generate
|
||||
|
||||
Reference in New Issue
Block a user