fix: providing a tensor to cache_position in model.generate kwargs always crashes because of boolean test (#39300)
* fix: cache_position: RuntimeError: Boolean value of Tensor with more than one value is ambiguous
* test cache_position
* move test
* propagate changes

---------

Co-authored-by: Masataro Asai <guicho2.71828@gmail.com>
This commit is contained in:
@@ -1800,7 +1800,7 @@ class GenerationMixin(ContinuousMixin):
|
|||||||
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
|
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
|
||||||
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
|
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
|
||||||
# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
|
# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
|
||||||
if "cache_position" in model_kwargs and model_kwargs["cache_position"]:
|
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
|
if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
|
||||||
cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
|
cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
|
||||||
|
|||||||
@@ -168,34 +168,6 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
return config, filtered_inputs_dict
|
return config, filtered_inputs_dict
|
||||||
|
|
||||||
def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
    """
    Asserts that a pair of generate outputs are similar. Two `generate` call outputs are considered similar in
    the following situations:
    1. The sequences are the same
    2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical

    Args:
        output_1: The first `generate` call output; must expose `sequences` and `scores`.
        output_2: The second `generate` call output; must expose `sequences` and `scores`.
        atol (`float`, *optional*, defaults to 1e-5): The absolute tolerance for the scores.
        rtol (`float`, *optional*, defaults to 1e-5): The relative tolerance for the scores.
    """
    # scores doesn't include data regarding decoder input tokens
    decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
    output_matches = output_1.sequences == output_2.sequences
    has_matching_outputs = bool(output_matches.all())
    has_matching_scores = None
    if not has_matching_outputs:
        for batch_idx in range(output_1.sequences.shape[0]):
            batch_matches = output_matches[batch_idx]
            if batch_matches.all():
                continue
            # `argmin` on the 0/1 match mask gives the index of the first False; shift it so that it indexes
            # `scores`, which only covers the generated positions
            first_mismatch_idx = int(batch_matches.int().argmin()) - decoder_input_length
            output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
            output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
            # Each tolerance is forwarded to the keyword of the same name (they used to be swapped)
            has_matching_scores = torch.allclose(
                output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=rtol, atol=atol
            )
            if not has_matching_scores:
                break
    self.assertTrue(has_matching_outputs or has_matching_scores)
|
|
||||||
|
|
||||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||||
logits_processor_kwargs = {
|
logits_processor_kwargs = {
|
||||||
"bad_words_ids": [[1, 0]],
|
"bad_words_ids": [[1, 0]],
|
||||||
@@ -1094,7 +1066,7 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
low_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=True)
|
low_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=True)
|
||||||
high_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=False)
|
high_output = model.generate(**inputs_dict, **generate_kwargs, low_memory=False)
|
||||||
self._check_similar_generate_outputs(low_output, high_output)
|
self.assertTrue(has_similar_generate_outputs(low_output, high_output))
|
||||||
|
|
||||||
@parameterized.expand([("random",), ("same",)])
|
@parameterized.expand([("random",), ("same",)])
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
@@ -1176,7 +1148,7 @@ class GenerationTesterMixin:
|
|||||||
output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
|
output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
|
||||||
|
|
||||||
# The two outputs must match and their shape must be as expected
|
# The two outputs must match and their shape must be as expected
|
||||||
self._check_similar_generate_outputs(output_greedy, output_assisted)
|
self.assertTrue(has_similar_generate_outputs(output_greedy, output_assisted))
|
||||||
for output in (output_greedy, output_assisted):
|
for output in (output_greedy, output_assisted):
|
||||||
self._check_generate_outputs(output, model.config, use_cache=True)
|
self._check_generate_outputs(output, model.config, use_cache=True)
|
||||||
|
|
||||||
@@ -1259,7 +1231,7 @@ class GenerationTesterMixin:
|
|||||||
output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
|
output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
|
||||||
|
|
||||||
# The two outputs must match and their shape must be as expected
|
# The two outputs must match and their shape must be as expected
|
||||||
self._check_similar_generate_outputs(output_greedy, output_prompt_lookup)
|
self.assertTrue(has_similar_generate_outputs(output_greedy, output_prompt_lookup))
|
||||||
for output in (output_greedy, output_prompt_lookup):
|
for output in (output_greedy, output_prompt_lookup):
|
||||||
self._check_generate_outputs(output, model.config, use_cache=True)
|
self._check_generate_outputs(output, model.config, use_cache=True)
|
||||||
|
|
||||||
@@ -1745,7 +1717,7 @@ class GenerationTesterMixin:
|
|||||||
input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||||
)
|
)
|
||||||
if not has_complex_embeds_computation:
|
if not has_complex_embeds_computation:
|
||||||
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
|
self.assertTrue(has_similar_generate_outputs(outputs_from_ids, outputs_from_embeds))
|
||||||
|
|
||||||
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
|
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
|
||||||
# be the same
|
# be the same
|
||||||
@@ -1754,7 +1726,7 @@ class GenerationTesterMixin:
|
|||||||
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||||
)
|
)
|
||||||
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
|
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
|
||||||
self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)
|
self.assertTrue(has_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds))
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||||
@@ -1896,7 +1868,7 @@ class GenerationTesterMixin:
|
|||||||
outputs_cached.scores = full_cached_scores
|
outputs_cached.scores = full_cached_scores
|
||||||
|
|
||||||
# The two sets of generated text and past kv should be equal to each other
|
# The two sets of generated text and past kv should be equal to each other
|
||||||
self._check_similar_generate_outputs(outputs, outputs_cached)
|
self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached))
|
||||||
for layer_idx in range(len(outputs_cached.past_key_values)):
|
for layer_idx in range(len(outputs_cached.past_key_values)):
|
||||||
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
@@ -1923,7 +1895,7 @@ class GenerationTesterMixin:
|
|||||||
if config.get_text_config(decoder=True).is_encoder_decoder:
|
if config.get_text_config(decoder=True).is_encoder_decoder:
|
||||||
self.skipTest(reason="This model is encoder-decoder")
|
self.skipTest(reason="This model is encoder-decoder")
|
||||||
# TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
|
# TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
|
||||||
# but it breaks a few models. Fix and then apply `_check_similar_generate_outputs` pattern
|
# but it breaks a few models. Fix and then apply `has_similar_generate_outputs` pattern
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
|
|
||||||
@@ -2038,7 +2010,7 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
# Check 2: The outputs must be similar to the case with dynamic cache
|
# Check 2: The outputs must be similar to the case with dynamic cache
|
||||||
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
|
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
|
||||||
self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
|
self.assertTrue(has_similar_generate_outputs(dynamic_cache_generation, static_cache_generation))
|
||||||
|
|
||||||
@require_optimum_quanto
|
@require_optimum_quanto
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
@@ -2192,7 +2164,7 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
|
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
|
||||||
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
self.assertTrue(has_similar_generate_outputs(dynamic_result, compiled_result))
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_generate_compilation_all_outputs(self):
|
def test_generate_compilation_all_outputs(self):
|
||||||
@@ -2381,7 +2353,7 @@ class GenerationTesterMixin:
|
|||||||
del model_attn
|
del model_attn
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
|
self.assertTrue(has_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3))
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
@require_torch_sdpa
|
@require_torch_sdpa
|
||||||
@@ -5087,6 +5059,97 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
assert value == "success"
|
assert value == "success"
|
||||||
|
|
||||||
|
@pytest.mark.generate
def test_generate_custom_cache_position(self):
    """
    Regression test for #39261. Tests that we can continue generating from past key values, returned from a
    previous `generate` call, without the tokens that correspond to the cached part. This is achieved by
    manually creating `cache_position` and passing it to `generate` -- this tests that it is piped correctly.
    """
    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")

    model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
    generate_kwargs = {
        "use_cache": True,
        "do_sample": False,
        "return_dict_in_generate": True,
        "output_scores": True,
    }

    # Traditional way to continue generating text using kv cache: the second call receives the full first
    # output followed by the new input. (Schematic, not to scale; I = input tokens, O = generated tokens.)
    # output2
    # /~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\
    # input2
    # /~~~~~~~~~~~~~~~~~~~~~~~~\
    # output1
    # /~~~~~~~~~~~~~~~~\
    # input1
    # /~~~~~~\
    # IIIIIIIIOOOOOOOOOOIIIIIIIIOOOOOOOOOOOOOOOOOO
    inputs_1a = model_inputs
    outputs_1a = model.generate(**inputs_1a, **generate_kwargs, max_new_tokens=2)
    inputs_2a = {**model_inputs}
    inputs_2a["input_ids"] = torch.cat((outputs_1a.sequences, model_inputs["input_ids"]), dim=1)
    # Right-pad the original mask with ones so that it covers every token in the new `input_ids`
    inputs_2a["attention_mask"] = torch.nn.functional.pad(
        inputs_1a["attention_mask"],
        (0, inputs_2a["input_ids"].shape[1] - inputs_1a["input_ids"].shape[1]),
        mode="constant",
        value=1,
    )
    inputs_2a["past_key_values"] = outputs_1a.past_key_values
    outputs_2a = model.generate(**inputs_2a, **generate_kwargs, max_new_tokens=2)
    # Keep only the part of the output related to the second output + last token from the first output, for future
    # comparison
    traditional_outputs = copy.deepcopy(outputs_2a)
    traditional_outputs.sequences = traditional_outputs.sequences[:, outputs_1a.sequences.shape[1] - 1 :]

    # Continue generating text using kv cache, but without providing the cached part of the input in the input_ids.
    # (Schematic, not to scale; I = input tokens, O = generated tokens.)
    #                   cache_position
    #                   /~~~~~~~\
    # inputs2["attention_mask"]
    # /~~~~~~~~~~~~~~~~~~~~~~~~~\
    # output1           output2
    # /~~~~~~~~~~~~~~~~\/~~~~~~~~~~~~~~~~~~~~~~~~~\
    # input1            input2
    # /~~~~~~\          /~~~~~~~\
    # IIIIIIIIOOOOOOOOOOIIIIIIIIIOOOOOOOOOOOOOOOOOO
    #
    # NOTE(review): the first step is re-run instead of reusing `outputs_1a` -- presumably because the cache in
    # `outputs_1a.past_key_values` was already extended by the `outputs_2a` call above; confirm.
    inputs_1b = model_inputs
    outputs_1b = model.generate(**inputs_1b, **generate_kwargs, max_new_tokens=2)
    inputs_2b = {**model_inputs}
    # The last output token isn't cached, so it needs to be included in the new input
    inputs_2b["input_ids"] = torch.cat((outputs_1b.sequences[:, -1:], model_inputs["input_ids"]), dim=1)
    # Unlike `input_ids`, the mask is padded by the full length of the first output's sequences
    inputs_2b["attention_mask"] = torch.nn.functional.pad(
        inputs_1b["attention_mask"],
        (0, outputs_1b.sequences.shape[1]),
        mode="constant",
        value=1,
    )
    inputs_2b["past_key_values"] = outputs_1b.past_key_values
    # NOTE(review): assumes dim -2 of the first cache entry is the cached sequence length -- confirm
    cache_length_1b = outputs_1b.past_key_values[0][0].shape[-2]
    # One position per new input token, starting right after the cached positions
    inputs_2b["cache_position"] = torch.arange(
        cache_length_1b,
        cache_length_1b + inputs_2b["input_ids"].shape[1],
        dtype=torch.int64,
        device=model.device,
    )
    outputs_2b = model.generate(**inputs_2b, **generate_kwargs, max_new_tokens=2)
    incremental_outputs = outputs_2b

    # The two sets of generated text and past kv should be equal to each other
    self.assertTrue(has_similar_generate_outputs(traditional_outputs, incremental_outputs))
    for layer_idx in range(len(traditional_outputs.past_key_values)):
        for kv_idx in range(len(traditional_outputs.past_key_values[layer_idx])):
            self.assertTrue(
                torch.allclose(
                    traditional_outputs.past_key_values[layer_idx][kv_idx],
                    incremental_outputs.past_key_values[layer_idx][kv_idx],
                )
            )
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TokenHealingTestCase(unittest.TestCase):
|
class TokenHealingTestCase(unittest.TestCase):
|
||||||
@@ -5281,3 +5344,41 @@ class TestAssistedCandidateGeneratorUpdateStrategy(unittest.TestCase):
|
|||||||
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
|
self.assertEqual(self.assistant_model.generation_config.assistant_confidence_threshold, 0.4)
|
||||||
else:
|
else:
|
||||||
self.assert_no_sklearn()
|
self.assert_no_sklearn()
|
||||||
|
|
||||||
|
|
||||||
|
def has_similar_generate_outputs(output_1, output_2, atol=1e-5, rtol=1e-5) -> bool:
    """
    Returns a boolean indicating whether a pair of generate outputs are similar. Two `generate` call outputs are
    considered similar in the following situations:
    1. The sequences are the same
    2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical

    Args:
        output_1 (`GenerateOutput`): The first `generate` call output.
        output_2 (`GenerateOutput`): The second `generate` call output.
        atol (`float`, *optional*, defaults to 1e-5): The absolute tolerance for the scores.
        rtol (`float`, *optional*, defaults to 1e-5): The relative tolerance for the scores.

    Returns:
        A boolean indicating whether the two generate outputs are similar.
    """
    # scores doesn't include data regarding decoder input tokens
    decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
    output_matches = output_1.sequences == output_2.sequences
    # Convert to a plain `bool` so the return value honours the annotation (it used to leak a 0-dim tensor)
    if bool(output_matches.all()):
        return True

    for batch_idx in range(output_1.sequences.shape[0]):
        batch_matches = output_matches[batch_idx]
        if batch_matches.all():
            continue
        # `argmin` on the 0/1 match mask gives the index of the first False; shift it so that it indexes
        # `scores`, which only covers the generated positions
        first_mismatch_idx = int(batch_matches.int().argmin()) - decoder_input_length
        output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
        output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
        # Each tolerance is forwarded to the keyword of the same name (they used to be swapped)
        if not torch.allclose(
            output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=rtol, atol=atol
        ):
            return False
    # Every diverging batch row had near-identical scores at its first mismatch
    return True
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from transformers.testing_utils import (
|
|||||||
from transformers.utils import is_soundfile_available, is_torch_available, is_torchaudio_available
|
from transformers.utils import is_soundfile_available, is_torch_available, is_torchaudio_available
|
||||||
from transformers.utils.import_utils import is_datasets_available
|
from transformers.utils.import_utils import is_datasets_available
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
@@ -512,7 +512,7 @@ class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
outputs_cached.scores = full_cached_scores
|
outputs_cached.scores = full_cached_scores
|
||||||
|
|
||||||
# The two sets of generated text and past kv should be equal to each other
|
# The two sets of generated text and past kv should be equal to each other
|
||||||
self._check_similar_generate_outputs(outputs, outputs_cached)
|
self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached))
|
||||||
for layer_idx in range(len(outputs_cached.past_key_values)):
|
for layer_idx in range(len(outputs_cached.past_key_values)):
|
||||||
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ from transformers.utils import (
|
|||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import (
|
from ...test_modeling_common import (
|
||||||
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
||||||
@@ -650,7 +650,7 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
outputs_from_embeds = model.generate(
|
outputs_from_embeds = model.generate(
|
||||||
input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||||
)
|
)
|
||||||
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
|
self.assertTrue(has_similar_generate_outputs(outputs_from_ids, outputs_from_embeds))
|
||||||
|
|
||||||
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
|
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
|
||||||
# be the same
|
# be the same
|
||||||
@@ -658,7 +658,7 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||||
)
|
)
|
||||||
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
|
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
|
||||||
self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)
|
self.assertTrue(has_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds))
|
||||||
|
|
||||||
|
|
||||||
# We will verify our results on an image of cute cats
|
# We will verify our results on an image of cute cats
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ from transformers.testing_utils import (
|
|||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import (
|
from ...test_modeling_common import (
|
||||||
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
||||||
@@ -527,7 +527,7 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
|||||||
outputs_cached.scores = full_cached_scores
|
outputs_cached.scores = full_cached_scores
|
||||||
|
|
||||||
# The two sets of generated text and past kv should be equal to each other
|
# The two sets of generated text and past kv should be equal to each other
|
||||||
self._check_similar_generate_outputs(outputs, outputs_cached)
|
self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached))
|
||||||
for layer_idx in range(len(outputs_cached.past_key_values)):
|
for layer_idx in range(len(outputs_cached.past_key_values)):
|
||||||
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
@@ -613,7 +613,7 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
|||||||
del model_attn
|
del model_attn
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
|
self.assertTrue(has_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from transformers.testing_utils import (
|
|||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin, has_similar_generate_outputs
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
@@ -1196,7 +1196,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
outputs_cached.scores = full_cached_scores
|
outputs_cached.scores = full_cached_scores
|
||||||
|
|
||||||
# The two sets of generated text and past kv should be equal to each other
|
# The two sets of generated text and past kv should be equal to each other
|
||||||
self._check_similar_generate_outputs(outputs, outputs_cached)
|
self.assertTrue(has_similar_generate_outputs(outputs, outputs_cached))
|
||||||
for layer_idx in range(len(outputs_cached.past_key_values)):
|
for layer_idx in range(len(outputs_cached.past_key_values)):
|
||||||
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
|||||||
Reference in New Issue
Block a user