fix: providing a tensor to cache_position in model.generate kwargs always crashes because of boolean test (#39300)
* fix: cache_position: RuntimeError: Boolean value of Tensor with more than one value is ambiguous * test cache_position * move test * propagate changes --------- Co-authored-by: Masataro Asai <guicho2.71828@gmail.com>
This commit is contained in:
@@ -168,34 +168,6 @@ class GenerationTesterMixin:
|
||||
|
||||
return config, filtered_inputs_dict
|
||||
|
||||
def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
|
||||
"""
|
||||
Checks whether a pair of generate outputs are similar. Two `generate` call outputs are considered similar in
|
||||
the following situations:
|
||||
1. The sequences are the same
|
||||
2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical
|
||||
"""
|
||||
# scores doesn't include data regarding decoder input tokens
|
||||
decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
|
||||
output_matches = output_1.sequences == output_2.sequences
|
||||
has_matching_outputs = output_matches.all()
|
||||
has_matching_scores = None
|
||||
if not has_matching_outputs:
|
||||
for batch_idx in range(output_1.sequences.shape[0]):
|
||||
batch_matches = output_matches[batch_idx]
|
||||
if batch_matches.all():
|
||||
continue
|
||||
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
|
||||
first_mismatch_idx -= decoder_input_length
|
||||
output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
|
||||
output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
|
||||
has_matching_scores = torch.allclose(
|
||||
output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol
|
||||
)
|
||||
if not has_matching_scores:
|
||||
break
|
||||
self.assertTrue(has_matching_outputs or has_matching_scores)
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {
|
||||
"bad_words_ids": [[1, 0]],
|
||||
@@ -1094,7 +1066,7 @@ class GenerationTesterMixin:
|
||||
|
||||
low_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=True)
|
||||
high_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=False)
|
||||
self._check_similar_generate_outputs(low_output, high_output)
|
||||
self.assertTrue(has_similar_generate_outputs(low_output, high_output))
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@pytest.mark.generate
|
||||
@@ -1176,7 +1148,7 @@ class GenerationTesterMixin:
|
||||
output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self._check_similar_generate_outputs(output_greedy, output_assisted)
|
||||
self.assertTrue(has_similar_generate_outputs(output_greedy, output_assisted))
|
||||
for output in (output_greedy, output_assisted):
|
||||
self._check_generate_outputs(output, model.config, use_cache=True)
|
||||
|
||||
@@ -1259,7 +1231,7 @@ class GenerationTesterMixin:
|
||||
output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self._check_similar_generate_outputs(output_greedy, output_prompt_lookup)
|
||||
self.assertTrue(has_similar_generate_outputs(output_greedy, output_prompt_lookup))
|
||||
for output in (output_greedy, output_prompt_lookup):
|
||||
self._check_generate_outputs(output, model.config, use_cache=True)
|
||||
|
||||
@@ -1745,7 +1717,7 @@ class GenerationTesterMixin:
|
||||
input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
if not has_complex_embeds_computation:
|
||||
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
|
||||
self.assertTrue(has_similar_generate_outputs(outputs_from_ids, outputs_from_embeds))
|
||||
|
||||
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
|
||||
# be the same
|
||||
@@ -1754,7 +1726,7 @@ class GenerationTesterMixin:
|
||||
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
|
||||
self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)
|
||||
self.assertTrue(has_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds))
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
@@ -1896,7 +1868,7 @@ class GenerationTesterMixin:
|
||||
outputs_cached.scores = full_cached_scores
|
||||
|
||||
# The two sets of generated text and past kv should be equal to each other
|
||||
self._check_similar_generate_outputs(outputs, outputs_cached)
|
||||
self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached))
|
||||
for layer_idx in range(len(outputs_cached.past_key_values)):
|
||||
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
||||
self.assertTrue(
|
||||
@@ -1923,7 +1895,7 @@ class GenerationTesterMixin:
|
||||
if config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
self.skipTest(reason="This model is encoder-decoder")
|
||||
# TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
|
||||
# but it breaks a few models. Fix and then apply `_check_similar_generate_outputs` pattern
|
||||
# but it breaks a few models. Fix and then apply `has_similar_generate_outputs` pattern
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
@@ -2038,7 +2010,7 @@ class GenerationTesterMixin:
|
||||
|
||||
# Check 2: The outputs must be similar to the case with dynamic cache
|
||||
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
|
||||
self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
|
||||
self.assertTrue(has_similar_generate_outputs(dynamic_cache_generation, static_cache_generation))
|
||||
|
||||
@require_optimum_quanto
|
||||
@pytest.mark.generate
|
||||
@@ -2192,7 +2164,7 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
|
||||
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
||||
self.assertTrue(has_similar_generate_outputs(dynamic_result, compiled_result))
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_compilation_all_outputs(self):
|
||||
@@ -2381,7 +2353,7 @@ class GenerationTesterMixin:
|
||||
del model_attn
|
||||
gc.collect()
|
||||
|
||||
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
|
||||
self.assertTrue(has_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3))
|
||||
|
||||
@pytest.mark.generate
|
||||
@require_torch_sdpa
|
||||
@@ -5087,6 +5059,97 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
assert value == "success"
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_custom_cache_position(self):
|
||||
"""
|
||||
Regression test for #39261. Tests that we can continue generating from past key values, returned from a
|
||||
previous `generate` call, without the tokens that correspond to the cached part. This is achieved by passing
|
||||
manually creating `cache_position` -- this tests that it is piped correctly.
|
||||
"""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||
|
||||
model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
|
||||
generate_kwargs = {
|
||||
"use_cache": True,
|
||||
"do_sample": False,
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
}
|
||||
|
||||
# Traditional way to continue generating text using kv cache
|
||||
# output2
|
||||
# /~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\
|
||||
# input2
|
||||
# /~~~~~~~~~~~~~~~~~~~~~~~~\
|
||||
# output1
|
||||
# /~~~~~~~~~~~~~~~~\
|
||||
# input1
|
||||
# /~~~~~~\
|
||||
# IIIIIIIIOOOOOOOOOOIIIIIIIIOOOOOOOOOOOOOOOOOO
|
||||
inputs_1a = model_inputs
|
||||
outputs_1a = model.generate(**inputs_1a, **generate_kwargs, max_new_tokens=2)
|
||||
inputs_2a = {**model_inputs}
|
||||
inputs_2a["input_ids"] = torch.cat((outputs_1a.sequences, model_inputs["input_ids"]), dim=1)
|
||||
inputs_2a["attention_mask"] = torch.nn.functional.pad(
|
||||
inputs_1a["attention_mask"],
|
||||
(0, inputs_2a["input_ids"].shape[1] - inputs_1a["input_ids"].shape[1]),
|
||||
mode="constant",
|
||||
value=1,
|
||||
)
|
||||
inputs_2a["past_key_values"] = outputs_1a.past_key_values
|
||||
outputs_2a = model.generate(**inputs_2a, **generate_kwargs, max_new_tokens=2)
|
||||
# Keep only the part of the output related to the second output + last token from the first output, for future
|
||||
# comparison
|
||||
traditional_outputs = copy.deepcopy(outputs_2a)
|
||||
traditional_outputs.sequences = traditional_outputs.sequences[:, outputs_1a.sequences.shape[1] - 1 :]
|
||||
|
||||
# Continue generating text using kv cache, but without providing the cached part of the input in the input_ids.
|
||||
# cache_position
|
||||
# /~~~~~~~\
|
||||
# inputs2["attention_mask"]
|
||||
# /~~~~~~~~~~~~~~~~~~~~~~~~~\
|
||||
# output1 output2
|
||||
# /~~~~~~~~~~~~~~~~\/~~~~~~~~~~~~~~~~~~~~~~~~~\
|
||||
# input1 input2
|
||||
# /~~~~~~\ /~~~~~~~\
|
||||
# IIIIIIIIOOOOOOOOOOIIIIIIIIIOOOOOOOOOOOOOOOOOO
|
||||
#
|
||||
inputs_1b = model_inputs
|
||||
outputs_1b = model.generate(**inputs_1b, **generate_kwargs, max_new_tokens=2)
|
||||
inputs_2b = {**model_inputs}
|
||||
# The last output token isn't cached, so it needs to be included in the new input
|
||||
inputs_2b["input_ids"] = torch.cat((outputs_1b.sequences[:, -1:], model_inputs["input_ids"]), dim=1)
|
||||
inputs_2b["attention_mask"] = torch.nn.functional.pad(
|
||||
inputs_1b["attention_mask"],
|
||||
(0, outputs_1b.sequences.shape[1]),
|
||||
mode="constant",
|
||||
value=1,
|
||||
)
|
||||
inputs_2b["past_key_values"] = outputs_1b.past_key_values
|
||||
cache_length_1b = outputs_1b.past_key_values[0][0].shape[-2]
|
||||
inputs_2b["cache_position"] = torch.arange(
|
||||
cache_length_1b,
|
||||
cache_length_1b + inputs_2b["input_ids"].shape[1],
|
||||
dtype=torch.int64,
|
||||
device=model.device,
|
||||
)
|
||||
outputs_2b = model.generate(**inputs_2b, **generate_kwargs, max_new_tokens=2)
|
||||
incremental_outputs = outputs_2b
|
||||
|
||||
# The two sets of generated text and past kv should be equal to each other
|
||||
self.assertTrue(has_similar_generate_outputs(traditional_outputs, incremental_outputs))
|
||||
for layer_idx in range(len(traditional_outputs.past_key_values)):
|
||||
for kv_idx in range(len(traditional_outputs.past_key_values[layer_idx])):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
traditional_outputs.past_key_values[layer_idx][kv_idx],
|
||||
incremental_outputs.past_key_values[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
@@ -5281,3 +5344,41 @@ class TestAssistedCandidateGeneratorUpdateStrategy(unittest.TestCase):
|
||||
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
|
||||
else:
|
||||
self.assert_no_sklearn()
|
||||
|
||||
|
||||
def has_similar_generate_outputs(output_1, output_2, atol=1e-5, rtol=1e-5) -> bool:
|
||||
"""
|
||||
Returns a boolean indicating whether a pair of generate outputs are similar. Two `generate` call outputs are
|
||||
considered similar in the following situations:
|
||||
1. The sequences are the same
|
||||
2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical
|
||||
|
||||
Args:
|
||||
output_1 (`GenerateOutput`): The first `generate` call output.
|
||||
output_2 (`GenerateOutput`): The second `generate` call output.
|
||||
atol (`float`, *optional*, defaults to 1e-5): The absolute tolerance for the scores.
|
||||
rtol (`float`, *optional*, defaults to 1e-5): The relative tolerance for the scores.
|
||||
|
||||
Returns:
|
||||
A boolean indicating whether the two generate outputs are similar.
|
||||
"""
|
||||
# scores doesn't include data regarding decoder input tokens
|
||||
decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
|
||||
output_matches = output_1.sequences == output_2.sequences
|
||||
has_matching_outputs = output_matches.all()
|
||||
has_matching_scores = None
|
||||
if not has_matching_outputs:
|
||||
for batch_idx in range(output_1.sequences.shape[0]):
|
||||
batch_matches = output_matches[batch_idx]
|
||||
if batch_matches.all():
|
||||
continue
|
||||
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
|
||||
first_mismatch_idx -= decoder_input_length
|
||||
output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
|
||||
output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
|
||||
has_matching_scores = torch.allclose(
|
||||
output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol
|
||||
)
|
||||
if not has_matching_scores:
|
||||
break
|
||||
return has_matching_outputs or has_matching_scores
|
||||
|
||||
Reference in New Issue
Block a user