fix: providing a tensor to cache_position in model.generate kwargs always crashes because of boolean test (#39300)

* fix: cache_position: RuntimeError: Boolean value of Tensor with more than one value is ambiguous

* test cache_position

* move test

* propagate changes

---------

Co-authored-by: Masataro Asai <guicho2.71828@gmail.com>
This commit is contained in:
Joao Gante
2025-07-30 18:30:28 +01:00
committed by GitHub
parent 9b3203f47b
commit 4f93cc9174
6 changed files with 150 additions and 49 deletions

View File

@@ -168,34 +168,6 @@ class GenerationTesterMixin:
return config, filtered_inputs_dict
def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
"""
Checks whether a pair of generate outputs are similar. Two `generate` call outputs are considered similar in
the following situations:
1. The sequences are the same
2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical
"""
# scores doesn't include data regarding decoder input tokens
decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
output_matches = output_1.sequences == output_2.sequences
has_matching_outputs = output_matches.all()
has_matching_scores = None
if not has_matching_outputs:
for batch_idx in range(output_1.sequences.shape[0]):
batch_matches = output_matches[batch_idx]
if batch_matches.all():
continue
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
first_mismatch_idx -= decoder_input_length
output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
has_matching_scores = torch.allclose(
output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol
)
if not has_matching_scores:
break
self.assertTrue(has_matching_outputs or has_matching_scores)
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
logits_processor_kwargs = {
"bad_words_ids": [[1, 0]],
@@ -1094,7 +1066,7 @@ class GenerationTesterMixin:
low_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=True)
high_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=False)
self._check_similar_generate_outputs(low_output, high_output)
self.assertTrue(has_similar_generate_outputs(low_output, high_output))
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@@ -1176,7 +1148,7 @@ class GenerationTesterMixin:
output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
# The two outputs must match and their shape must be as expected
self._check_similar_generate_outputs(output_greedy, output_assisted)
self.assertTrue(has_similar_generate_outputs(output_greedy, output_assisted))
for output in (output_greedy, output_assisted):
self._check_generate_outputs(output, model.config, use_cache=True)
@@ -1259,7 +1231,7 @@ class GenerationTesterMixin:
output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
# The two outputs must match and their shape must be as expected
self._check_similar_generate_outputs(output_greedy, output_prompt_lookup)
self.assertTrue(has_similar_generate_outputs(output_greedy, output_prompt_lookup))
for output in (output_greedy, output_prompt_lookup):
self._check_generate_outputs(output, model.config, use_cache=True)
@@ -1745,7 +1717,7 @@ class GenerationTesterMixin:
input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
)
if not has_complex_embeds_computation:
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
self.assertTrue(has_similar_generate_outputs(outputs_from_ids, outputs_from_embeds))
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
# be the same
@@ -1754,7 +1726,7 @@ class GenerationTesterMixin:
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
)
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)
self.assertTrue(has_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds))
@pytest.mark.generate
def test_generate_from_inputs_embeds_with_static_cache(self):
@@ -1896,7 +1868,7 @@ class GenerationTesterMixin:
outputs_cached.scores = full_cached_scores
# The two sets of generated text and past kv should be equal to each other
self._check_similar_generate_outputs(outputs, outputs_cached)
self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached))
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(
@@ -1923,7 +1895,7 @@ class GenerationTesterMixin:
if config.get_text_config(decoder=True).is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder")
# TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
# but it breaks a few models. Fix and then apply `_check_similar_generate_outputs` pattern
# but it breaks a few models. Fix and then apply `has_similar_generate_outputs` pattern
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -2038,7 +2010,7 @@ class GenerationTesterMixin:
# Check 2: The outputs must be similar to the case with dynamic cache
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
self.assertTrue(has_similar_generate_outputs(dynamic_cache_generation, static_cache_generation))
@require_optimum_quanto
@pytest.mark.generate
@@ -2192,7 +2164,7 @@ class GenerationTesterMixin:
)
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
self._check_similar_generate_outputs(dynamic_result, compiled_result)
self.assertTrue(has_similar_generate_outputs(dynamic_result, compiled_result))
@pytest.mark.generate
def test_generate_compilation_all_outputs(self):
@@ -2381,7 +2353,7 @@ class GenerationTesterMixin:
del model_attn
gc.collect()
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
self.assertTrue(has_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3))
@pytest.mark.generate
@require_torch_sdpa
@@ -5087,6 +5059,97 @@ class GenerationIntegrationTests(unittest.TestCase):
)
assert value == "success"
@pytest.mark.generate
def test_generate_custom_cache_position(self):
"""
Regression test for #39261. Tests that we can continue generating from past key values, returned from a
previous `generate` call, without the tokens that correspond to the cached part. This is achieved by passing
manually creating `cache_position` -- this tests that it is piped correctly.
"""
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
generate_kwargs = {
"use_cache": True,
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
}
# Traditional way to continue generating text using kv cache
# output2
# /~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\
# input2
# /~~~~~~~~~~~~~~~~~~~~~~~~\
# output1
# /~~~~~~~~~~~~~~~~\
# input1
# /~~~~~~\
# IIIIIIIIOOOOOOOOOOIIIIIIIIOOOOOOOOOOOOOOOOOO
inputs_1a = model_inputs
outputs_1a = model.generate(**inputs_1a, **generate_kwargs, max_new_tokens=2)
inputs_2a = {**model_inputs}
inputs_2a["input_ids"] = torch.cat((outputs_1a.sequences, model_inputs["input_ids"]), dim=1)
inputs_2a["attention_mask"] = torch.nn.functional.pad(
inputs_1a["attention_mask"],
(0, inputs_2a["input_ids"].shape[1] - inputs_1a["input_ids"].shape[1]),
mode="constant",
value=1,
)
inputs_2a["past_key_values"] = outputs_1a.past_key_values
outputs_2a = model.generate(**inputs_2a, **generate_kwargs, max_new_tokens=2)
# Keep only the part of the output related to the second output + last token from the first output, for future
# comparison
traditional_outputs = copy.deepcopy(outputs_2a)
traditional_outputs.sequences = traditional_outputs.sequences[:, outputs_1a.sequences.shape[1] - 1 :]
# Continue generating text using kv cache, but without providing the cached part of the input in the input_ids.
# cache_position
# /~~~~~~~\
# inputs2["attention_mask"]
# /~~~~~~~~~~~~~~~~~~~~~~~~~\
# output1 output2
# /~~~~~~~~~~~~~~~~\/~~~~~~~~~~~~~~~~~~~~~~~~~\
# input1 input2
# /~~~~~~\ /~~~~~~~\
# IIIIIIIIOOOOOOOOOOIIIIIIIIIOOOOOOOOOOOOOOOOOO
#
inputs_1b = model_inputs
outputs_1b = model.generate(**inputs_1b, **generate_kwargs, max_new_tokens=2)
inputs_2b = {**model_inputs}
# The last output token isn't cached, so it needs to be included in the new input
inputs_2b["input_ids"] = torch.cat((outputs_1b.sequences[:, -1:], model_inputs["input_ids"]), dim=1)
inputs_2b["attention_mask"] = torch.nn.functional.pad(
inputs_1b["attention_mask"],
(0, outputs_1b.sequences.shape[1]),
mode="constant",
value=1,
)
inputs_2b["past_key_values"] = outputs_1b.past_key_values
cache_length_1b = outputs_1b.past_key_values[0][0].shape[-2]
inputs_2b["cache_position"] = torch.arange(
cache_length_1b,
cache_length_1b + inputs_2b["input_ids"].shape[1],
dtype=torch.int64,
device=model.device,
)
outputs_2b = model.generate(**inputs_2b, **generate_kwargs, max_new_tokens=2)
incremental_outputs = outputs_2b
# The two sets of generated text and past kv should be equal to each other
self.assertTrue(has_similar_generate_outputs(traditional_outputs, incremental_outputs))
for layer_idx in range(len(traditional_outputs.past_key_values)):
for kv_idx in range(len(traditional_outputs.past_key_values[layer_idx])):
self.assertTrue(
torch.allclose(
traditional_outputs.past_key_values[layer_idx][kv_idx],
incremental_outputs.past_key_values[layer_idx][kv_idx],
)
)
@require_torch
class TokenHealingTestCase(unittest.TestCase):
@@ -5281,3 +5344,41 @@ class TestAssistedCandidateGeneratorUpdateStrategy(unittest.TestCase):
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
else:
self.assert_no_sklearn()
def has_similar_generate_outputs(output_1, output_2, atol=1e-5, rtol=1e-5) -> bool:
"""
Returns a boolean indicating whether a pair of generate outputs are similar. Two `generate` call outputs are
considered similar in the following situations:
1. The sequences are the same
2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical
Args:
output_1 (`GenerateOutput`): The first `generate` call output.
output_2 (`GenerateOutput`): The second `generate` call output.
atol (`float`, *optional*, defaults to 1e-5): The absolute tolerance for the scores.
rtol (`float`, *optional*, defaults to 1e-5): The relative tolerance for the scores.
Returns:
A boolean indicating whether the two generate outputs are similar.
"""
# scores doesn't include data regarding decoder input tokens
decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
output_matches = output_1.sequences == output_2.sequences
has_matching_outputs = output_matches.all()
has_matching_scores = None
if not has_matching_outputs:
for batch_idx in range(output_1.sequences.shape[0]):
batch_matches = output_matches[batch_idx]
if batch_matches.all():
continue
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
first_mismatch_idx -= decoder_input_length
output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
has_matching_scores = torch.allclose(
output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol
)
if not has_matching_scores:
break
return has_matching_outputs or has_matching_scores

View File

@@ -33,7 +33,7 @@ from transformers.testing_utils import (
from transformers.utils import is_soundfile_available, is_torch_available, is_torchaudio_available
from transformers.utils.import_utils import is_datasets_available
from ...generation.test_utils import GenerationTesterMixin
from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -512,7 +512,7 @@ class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
outputs_cached.scores = full_cached_scores
# The two sets of generated text and past kv should be equal to each other
self._check_similar_generate_outputs(outputs, outputs_cached)
self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached))
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(

View File

@@ -40,7 +40,7 @@ from transformers.utils import (
is_vision_available,
)
from ...generation.test_utils import GenerationTesterMixin
from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
@@ -650,7 +650,7 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
outputs_from_embeds = model.generate(
input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
)
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
self.assertTrue(has_similar_generate_outputs(outputs_from_ids, outputs_from_embeds))
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
# be the same
@@ -658,7 +658,7 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
)
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)
self.assertTrue(has_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds))
# We will verify our results on an image of cute cats

View File

@@ -38,7 +38,7 @@ from transformers.testing_utils import (
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
@@ -527,7 +527,7 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
outputs_cached.scores = full_cached_scores
# The two sets of generated text and past kv should be equal to each other
self._check_similar_generate_outputs(outputs, outputs_cached)
self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached))
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(
@@ -613,7 +613,7 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
del model_attn
gc.collect()
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
self.assertTrue(has_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3))
@require_torch

View File

@@ -29,7 +29,7 @@ from transformers.testing_utils import (
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -1196,7 +1196,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
outputs_cached.scores = full_cached_scores
# The two sets of generated text and past kv should be equal to each other
self._check_similar_generate_outputs(outputs, outputs_cached)
self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached))
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(