Generate: return past_key_values (#25086)
This commit is contained in:
@@ -1829,6 +1829,85 @@ class GenerationTesterMixin:
|
||||
outputs_from_embeds_wo_ids[:, 1:].tolist(),
|
||||
)
|
||||
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
# Tests that we can continue generating from past key values, returned from a previous `generate` call
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# won't fix: old models with unique inputs/caches/others
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
|
||||
return
|
||||
# may fix in the future: needs modeling or test input preparation fixes for compatibility
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
|
||||
return
|
||||
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# If it doesn't support cache, pass the test
|
||||
if not hasattr(config, "use_cache"):
|
||||
return
|
||||
|
||||
# Let's make it always:
|
||||
# 1. use cache (for obvious reasons)
|
||||
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
|
||||
# would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
|
||||
# continuation would force it to generate beyond an EOS token)
|
||||
# 3. ignore `token_type_ids` for simplicity
|
||||
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
|
||||
# active by default on some models
|
||||
config.use_cache = True
|
||||
if "token_type_ids" in inputs:
|
||||
del inputs["token_type_ids"]
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
|
||||
model.generation_config.forced_eos_token_id = None
|
||||
|
||||
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
|
||||
outputs = model(**inputs)
|
||||
if "past_key_values" not in outputs:
|
||||
return
|
||||
|
||||
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
|
||||
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
|
||||
|
||||
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
|
||||
# inputs may need to be tweaked across `generate` calls (like the attention mask).
|
||||
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
|
||||
|
||||
# Continue from the tokens generated above, preparing the inputs accordingly
|
||||
inputs["past_key_values"] = outputs_cached.past_key_values
|
||||
new_attention_len = outputs_cached.sequences.shape[-1]
|
||||
if config.is_encoder_decoder:
|
||||
inputs["decoder_input_ids"] = outputs_cached.sequences
|
||||
if "decoder_attention_mask" in inputs:
|
||||
inputs["decoder_attention_mask"] = torch.nn.functional.pad(
|
||||
inputs["decoder_attention_mask"],
|
||||
(0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
|
||||
mode="constant",
|
||||
value=1,
|
||||
)
|
||||
else:
|
||||
inputs["input_ids"] = outputs_cached.sequences
|
||||
if "attention_mask" in inputs:
|
||||
inputs["attention_mask"] = torch.nn.functional.pad(
|
||||
inputs["attention_mask"],
|
||||
(0, new_attention_len - inputs["attention_mask"].shape[1]),
|
||||
mode="constant",
|
||||
value=1,
|
||||
)
|
||||
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
|
||||
|
||||
# The two sets of generated text and past kv should be equal to each other
|
||||
self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
|
||||
for layer_idx in range(len(outputs_cached.past_key_values)):
|
||||
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
outputs.past_key_values[layer_idx][kv_idx],
|
||||
outputs_cached.past_key_values[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
@@ -1894,6 +1973,24 @@ class GenerationTesterMixin:
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Past Key Value States -- two notes here:
|
||||
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
|
||||
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
|
||||
# 3. TODO (joao): A few models have different formats, skipping those until the cache refactor is complete
|
||||
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer")
|
||||
has_standard_cache = not any(
|
||||
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
|
||||
)
|
||||
if use_cache and has_standard_cache:
|
||||
past_key_values = output.past_key_values
|
||||
past_sequence_length = output.sequences.shape[-1] - 1
|
||||
self._check_past_key_values_for_generate(
|
||||
num_sequences_in_output,
|
||||
past_key_values,
|
||||
seq_length=past_sequence_length,
|
||||
config=config,
|
||||
)
|
||||
|
||||
def _check_scores(self, batch_size, scores, length, config):
|
||||
expected_shape = (batch_size, config.vocab_size)
|
||||
self.assertIsInstance(scores, tuple)
|
||||
@@ -1959,6 +2056,30 @@ class GenerationTesterMixin:
|
||||
[encoder_expected_shape] * len(hidden_states),
|
||||
)
|
||||
|
||||
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
|
||||
self.assertIsInstance(past_key_values, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
|
||||
[True] * len(past_key_values),
|
||||
)
|
||||
|
||||
# (batch, head, seq_length, head_features)
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
|
||||
seq_length,
|
||||
config.hidden_size // config.num_attention_heads,
|
||||
)
|
||||
# check shape key, value
|
||||
self.assertListEqual(
|
||||
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
|
||||
[expected_shape] * len(past_key_values),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
|
||||
[expected_shape] * len(past_key_values),
|
||||
)
|
||||
|
||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
|
||||
# set to same device. we don't care what device.
|
||||
|
||||
Reference in New Issue
Block a user