diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index acd53a20b7..36210b3989 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1800,7 +1800,7 @@ class GenerationMixin(ContinuousMixin): def _get_initial_cache_position(self, seq_length, device, model_kwargs): """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` - if "cache_position" in model_kwargs and model_kwargs["cache_position"]: + if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None: return model_kwargs if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder: cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index d73bfa2274..0f7966a9c9 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -168,34 +168,6 @@ class GenerationTesterMixin: return config, filtered_inputs_dict - def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5): - """ - Checks whether a pair of generate outputs are similar. Two `generate` call outputs are considered similar in - the following situations: - 1. The sequences are the same - 2. 
The sequences are different, but the scores up to (and including) the first mismatch are nearly identical - """ - # scores doesn't include data regarding decoder input tokens - decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores) - output_matches = output_1.sequences == output_2.sequences - has_matching_outputs = output_matches.all() - has_matching_scores = None - if not has_matching_outputs: - for batch_idx in range(output_1.sequences.shape[0]): - batch_matches = output_matches[batch_idx] - if batch_matches.all(): - continue - first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False - first_mismatch_idx -= decoder_input_length - output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx] - output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx] - has_matching_scores = torch.allclose( - output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol - ) - if not has_matching_scores: - break - self.assertTrue(has_matching_outputs or has_matching_scores) - def _get_logits_processor_kwargs(self, do_sample=False, config=None): logits_processor_kwargs = { "bad_words_ids": [[1, 0]], @@ -1094,7 +1066,7 @@ class GenerationTesterMixin: low_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=True) high_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=False) - self._check_similar_generate_outputs(low_output, high_output) + self.assertTrue(has_similar_generate_outputs(low_output, high_output)) @parameterized.expand([("random",), ("same",)]) @pytest.mark.generate @@ -1176,7 +1148,7 @@ class GenerationTesterMixin: output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs) # The two outputs must match and their shape must be as expected - self._check_similar_generate_outputs(output_greedy, output_assisted) + self.assertTrue(has_similar_generate_outputs(output_greedy, output_assisted)) for 
output in (output_greedy, output_assisted): self._check_generate_outputs(output, model.config, use_cache=True) @@ -1259,7 +1231,7 @@ class GenerationTesterMixin: output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict) # The two outputs must match and their shape must be as expected - self._check_similar_generate_outputs(output_greedy, output_prompt_lookup) + self.assertTrue(has_similar_generate_outputs(output_greedy, output_prompt_lookup)) for output in (output_greedy, output_prompt_lookup): self._check_generate_outputs(output, model.config, use_cache=True) @@ -1745,7 +1717,7 @@ class GenerationTesterMixin: input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict ) if not has_complex_embeds_computation: - self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds) + self.assertTrue(has_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)) # input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will # be the same @@ -1754,7 +1726,7 @@ class GenerationTesterMixin: inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict ) outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :] - self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds) + self.assertTrue(has_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)) @pytest.mark.generate def test_generate_from_inputs_embeds_with_static_cache(self): @@ -1896,7 +1868,7 @@ class GenerationTesterMixin: outputs_cached.scores = full_cached_scores # The two sets of generated text and past kv should be equal to each other - self._check_similar_generate_outputs(outputs, outputs_cached) + self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached)) for layer_idx in range(len(outputs_cached.past_key_values)): for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])): self.assertTrue( @@ -1923,7 +1895,7 @@ 
class GenerationTesterMixin: if config.get_text_config(decoder=True).is_encoder_decoder: self.skipTest(reason="This model is encoder-decoder") # TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`, - # but it breaks a few models. Fix and then apply `_check_similar_generate_outputs` pattern + # but it breaks a few models. Fix and then apply `has_similar_generate_outputs` pattern if not hasattr(config, "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") @@ -2038,7 +2010,7 @@ class GenerationTesterMixin: # Check 2: The outputs must be similar to the case with dynamic cache dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict) - self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation) + self.assertTrue(has_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)) @require_optimum_quanto @pytest.mark.generate @@ -2192,7 +2164,7 @@ class GenerationTesterMixin: ) for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs): - self._check_similar_generate_outputs(dynamic_result, compiled_result) + self.assertTrue(has_similar_generate_outputs(dynamic_result, compiled_result)) @pytest.mark.generate def test_generate_compilation_all_outputs(self): @@ -2381,7 +2353,7 @@ class GenerationTesterMixin: del model_attn gc.collect() - self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3) + self.assertTrue(has_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)) @pytest.mark.generate @require_torch_sdpa @@ -5087,6 +5059,97 @@ class GenerationIntegrationTests(unittest.TestCase): ) assert value == "success" + @pytest.mark.generate + def test_generate_custom_cache_position(self): + """ + Regression test for #39261. Tests that we can continue generating from past key values, returned from a + previous `generate` call, without the tokens that correspond to the cached part. 
This is achieved by passing + a manually created `cache_position` -- this tests that it is piped correctly. + """ + model = AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto" + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") + + model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device) + generate_kwargs = { + "use_cache": True, + "do_sample": False, + "return_dict_in_generate": True, + "output_scores": True, + } + + # Traditional way to continue generating text using kv cache + # output2 + # /~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\ + # input2 + # /~~~~~~~~~~~~~~~~~~~~~~~~\ + # output1 + # /~~~~~~~~~~~~~~~~\ + # input1 + # /~~~~~~\ + # IIIIIIIIOOOOOOOOOOIIIIIIIIOOOOOOOOOOOOOOOOOO + inputs_1a = model_inputs + outputs_1a = model.generate(**inputs_1a, **generate_kwargs, max_new_tokens=2) + inputs_2a = {**model_inputs} + inputs_2a["input_ids"] = torch.cat((outputs_1a.sequences, model_inputs["input_ids"]), dim=1) + inputs_2a["attention_mask"] = torch.nn.functional.pad( + inputs_1a["attention_mask"], + (0, inputs_2a["input_ids"].shape[1] - inputs_1a["input_ids"].shape[1]), + mode="constant", + value=1, + ) + inputs_2a["past_key_values"] = outputs_1a.past_key_values + outputs_2a = model.generate(**inputs_2a, **generate_kwargs, max_new_tokens=2) + # Keep only the part of the output related to the second output + last token from the first output, for future + # comparison + traditional_outputs = copy.deepcopy(outputs_2a) + traditional_outputs.sequences = traditional_outputs.sequences[:, outputs_1a.sequences.shape[1] - 1 :] + + # Continue generating text using kv cache, but without providing the cached part of the input in the input_ids. 
+ # cache_position + # /~~~~~~~\ + # inputs2["attention_mask"] + # /~~~~~~~~~~~~~~~~~~~~~~~~~\ + # output1 output2 + # /~~~~~~~~~~~~~~~~\/~~~~~~~~~~~~~~~~~~~~~~~~~\ + # input1 input2 + # /~~~~~~\ /~~~~~~~\ + # IIIIIIIIOOOOOOOOOOIIIIIIIIIOOOOOOOOOOOOOOOOOO + # + inputs_1b = model_inputs + outputs_1b = model.generate(**inputs_1b, **generate_kwargs, max_new_tokens=2) + inputs_2b = {**model_inputs} + # The last output token isn't cached, so it needs to be included in the new input + inputs_2b["input_ids"] = torch.cat((outputs_1b.sequences[:, -1:], model_inputs["input_ids"]), dim=1) + inputs_2b["attention_mask"] = torch.nn.functional.pad( + inputs_1b["attention_mask"], + (0, outputs_1b.sequences.shape[1]), + mode="constant", + value=1, + ) + inputs_2b["past_key_values"] = outputs_1b.past_key_values + cache_length_1b = outputs_1b.past_key_values[0][0].shape[-2] + inputs_2b["cache_position"] = torch.arange( + cache_length_1b, + cache_length_1b + inputs_2b["input_ids"].shape[1], + dtype=torch.int64, + device=model.device, + ) + outputs_2b = model.generate(**inputs_2b, **generate_kwargs, max_new_tokens=2) + incremental_outputs = outputs_2b + + # The two sets of generated text and past kv should be equal to each other + self.assertTrue(has_similar_generate_outputs(traditional_outputs, incremental_outputs)) + for layer_idx in range(len(traditional_outputs.past_key_values)): + for kv_idx in range(len(traditional_outputs.past_key_values[layer_idx])): + self.assertTrue( + torch.allclose( + traditional_outputs.past_key_values[layer_idx][kv_idx], + incremental_outputs.past_key_values[layer_idx][kv_idx], + ) + ) + @require_torch class TokenHealingTestCase(unittest.TestCase): @@ -5281,3 +5344,41 @@ class TestAssistedCandidateGeneratorUpdateStrategy(unittest.TestCase): self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4) else: self.assert_no_sklearn() + + +def has_similar_generate_outputs(output_1, output_2, atol=1e-5, rtol=1e-5) -> bool: + """ 
+ Returns a boolean indicating whether a pair of generate outputs are similar. Two `generate` call outputs are + considered similar in the following situations: + 1. The sequences are the same + 2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical + + Args: + output_1 (`GenerateOutput`): The first `generate` call output. + output_2 (`GenerateOutput`): The second `generate` call output. + atol (`float`, *optional*, defaults to 1e-5): The absolute tolerance for the scores. + rtol (`float`, *optional*, defaults to 1e-5): The relative tolerance for the scores. + + Returns: + A boolean indicating whether the two generate outputs are similar. + """ + # scores doesn't include data regarding decoder input tokens + decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores) + output_matches = output_1.sequences == output_2.sequences + has_matching_outputs = output_matches.all() + has_matching_scores = None + if not has_matching_outputs: + for batch_idx in range(output_1.sequences.shape[0]): + batch_matches = output_matches[batch_idx] + if batch_matches.all(): + continue + first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False + first_mismatch_idx -= decoder_input_length + output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx] + output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx] + has_matching_scores = torch.allclose( + output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol + ) + if not has_matching_scores: + break + return has_matching_outputs or has_matching_scores diff --git a/tests/models/dia/test_modeling_dia.py b/tests/models/dia/test_modeling_dia.py index c6d1547afb..ecd4cf3a42 100644 --- a/tests/models/dia/test_modeling_dia.py +++ b/tests/models/dia/test_modeling_dia.py @@ -33,7 +33,7 @@ from transformers.testing_utils import ( from transformers.utils import is_soundfile_available, 
is_torch_available, is_torchaudio_available from transformers.utils.import_utils import is_datasets_available -from ...generation.test_utils import GenerationTesterMixin +from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -512,7 +512,7 @@ class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, outputs_cached.scores = full_cached_scores # The two sets of generated text and past kv should be equal to each other - self._check_similar_generate_outputs(outputs, outputs_cached) + self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached)) for layer_idx in range(len(outputs_cached.past_key_values)): for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])): self.assertTrue( diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py index 6e58c13f03..28aaa80254 100644 --- a/tests/models/kosmos2/test_modeling_kosmos2.py +++ b/tests/models/kosmos2/test_modeling_kosmos2.py @@ -40,7 +40,7 @@ from transformers.utils import ( is_vision_available, ) -from ...generation.test_utils import GenerationTesterMixin +from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, @@ -650,7 +650,7 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi outputs_from_embeds = model.generate( input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict ) - self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds) + self.assertTrue(has_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)) # input_ids is not a required input on most models -- if 
we don't pass it, the newly generated tokens will # be the same @@ -658,7 +658,7 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict ) outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :] - self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds) + self.assertTrue(has_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)) # We will verify our results on an image of cute cats diff --git a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py index 35898887fa..cb58eea261 100644 --- a/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py +++ b/tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py @@ -38,7 +38,7 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin +from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, @@ -527,7 +527,7 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel outputs_cached.scores = full_cached_scores # The two sets of generated text and past kv should be equal to each other - self._check_similar_generate_outputs(outputs, outputs_cached) + self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached)) for layer_idx in range(len(outputs_cached.past_key_values)): for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])): self.assertTrue( @@ -613,7 +613,7 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel del model_attn gc.collect() - self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, 
rtol=1e-3) + self.assertTrue(has_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)) @require_torch diff --git a/tests/models/t5gemma/test_modeling_t5gemma.py b/tests/models/t5gemma/test_modeling_t5gemma.py index 100787ece9..95279fae5b 100644 --- a/tests/models/t5gemma/test_modeling_t5gemma.py +++ b/tests/models/t5gemma/test_modeling_t5gemma.py @@ -29,7 +29,7 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin +from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -1196,7 +1196,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi outputs_cached.scores = full_cached_scores # The two sets of generated text and past kv should be equal to each other - self._check_similar_generate_outputs(outputs, outputs_cached) + self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached)) for layer_idx in range(len(outputs_cached.past_key_values)): for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])): self.assertTrue(