Tests: move generate tests to the right mixin and delete redundant tests (#34464)
* tmp commit * tmp commit * cull overwrites of deleted tests * typo * more specific docstring * make fixup * parameterize at the top? * correction * more deletions :D * tmp commit * for VLMs too * fix _check_outputs * test nit * make fixup * fix another flaky * test_generate_from_inputs_embeds -- handle missing attention mask
This commit is contained in:
@@ -29,6 +29,7 @@ from transformers import AutoConfig, is_torch_available, pipeline, set_seed
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_flash_attn,
|
||||
require_optimum_quanto,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
@@ -136,6 +137,34 @@ class GenerationTesterMixin:
|
||||
|
||||
return config, filtered_inputs_dict
|
||||
|
||||
def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
|
||||
"""
|
||||
Checks whether a pair of generate outputs are similar. Two `generate` call outputs are considered similar in
|
||||
the following siturations:
|
||||
1. The sequences are the same
|
||||
2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical
|
||||
"""
|
||||
# scores doesn't include data regarding decoder input tokens
|
||||
decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
|
||||
output_matches = output_1.sequences == output_2.sequences
|
||||
has_matching_outputs = output_matches.all()
|
||||
has_matching_scores = None
|
||||
if not has_matching_outputs:
|
||||
for batch_idx in range(output_1.sequences.shape[0]):
|
||||
batch_matches = output_matches[batch_idx]
|
||||
if batch_matches.all():
|
||||
continue
|
||||
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
|
||||
first_mismatch_idx -= decoder_input_length
|
||||
output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
|
||||
output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
|
||||
has_matching_scores = torch.allclose(
|
||||
output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=atol, atol=rtol
|
||||
)
|
||||
if not has_matching_scores:
|
||||
break
|
||||
self.assertTrue(has_matching_outputs or has_matching_scores)
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {
|
||||
"bad_words_ids": [[1, 0]],
|
||||
@@ -426,7 +455,6 @@ class GenerationTesterMixin:
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
@@ -453,13 +481,12 @@ class GenerationTesterMixin:
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(output_generate, main_input, model.config)
|
||||
self._check_outputs(output_generate, model.config)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
@@ -486,7 +513,7 @@ class GenerationTesterMixin:
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
)
|
||||
|
||||
self._check_outputs(output_generate, main_input, model.config, use_cache=True)
|
||||
self._check_outputs(output_generate, model.config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_sample_generate(self):
|
||||
@@ -505,7 +532,6 @@ class GenerationTesterMixin:
|
||||
def test_sample_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._sample_generate(
|
||||
@@ -533,7 +559,7 @@ class GenerationTesterMixin:
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(output_generate, main_input, model.config, num_return_sequences=2)
|
||||
self._check_outputs(output_generate, model.config, num_return_sequences=2)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_beam_search_generate(self):
|
||||
@@ -554,7 +580,6 @@ class GenerationTesterMixin:
|
||||
def test_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
@@ -583,14 +608,16 @@ class GenerationTesterMixin:
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(
|
||||
output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||
output_generate,
|
||||
model.config,
|
||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
@@ -623,10 +650,10 @@ class GenerationTesterMixin:
|
||||
|
||||
self._check_outputs(
|
||||
output_generate,
|
||||
main_input,
|
||||
model.config,
|
||||
use_cache=True,
|
||||
num_return_sequences=beam_kwargs["num_beams"],
|
||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
)
|
||||
|
||||
@require_accelerate
|
||||
@@ -675,7 +702,6 @@ class GenerationTesterMixin:
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
@@ -706,7 +732,10 @@ class GenerationTesterMixin:
|
||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(
|
||||
output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||
output_generate,
|
||||
model.config,
|
||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
)
|
||||
|
||||
@pytest.mark.generate
|
||||
@@ -765,7 +794,6 @@ class GenerationTesterMixin:
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||
@@ -794,7 +822,10 @@ class GenerationTesterMixin:
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(
|
||||
output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||
output_generate,
|
||||
model.config,
|
||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
)
|
||||
|
||||
# TODO: @gante check why it is flaky
|
||||
@@ -859,7 +890,6 @@ class GenerationTesterMixin:
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
@@ -899,7 +929,10 @@ class GenerationTesterMixin:
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(
|
||||
output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||
output_generate,
|
||||
model.config,
|
||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
)
|
||||
|
||||
@pytest.mark.generate
|
||||
@@ -942,7 +975,6 @@ class GenerationTesterMixin:
|
||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -968,7 +1000,7 @@ class GenerationTesterMixin:
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
)
|
||||
|
||||
self._check_outputs(output_generate, main_input, model.config, use_cache=True)
|
||||
self._check_outputs(output_generate, model.config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
@@ -1064,14 +1096,10 @@ class GenerationTesterMixin:
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||
# NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
|
||||
# shape differences -- and it may result in a different output. The input shape difference happens in the
|
||||
# main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
|
||||
# a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
|
||||
# NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
|
||||
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
|
||||
# NOTE: It breaks the pattern in the tests above, for multiple reasons:
|
||||
# - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
|
||||
# prepare the assistant encoder outputs in the main generate body);
|
||||
# - assisted_decoding does not support `use_cache = False`
|
||||
@@ -1100,7 +1128,6 @@ class GenerationTesterMixin:
|
||||
|
||||
# enable cache
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1141,12 +1168,10 @@ class GenerationTesterMixin:
|
||||
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
|
||||
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
|
||||
self._check_similar_generate_outputs(output_greedy, output_assisted)
|
||||
for output in (output_greedy, output_assisted):
|
||||
self._check_outputs(output, main_input, model.config, use_cache=True)
|
||||
self._check_outputs(output, model.config, use_cache=True)
|
||||
|
||||
@is_flaky()
|
||||
@pytest.mark.generate
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||
# This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
|
||||
@@ -1175,7 +1200,6 @@ class GenerationTesterMixin:
|
||||
|
||||
# enable cache
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1208,10 +1232,9 @@ class GenerationTesterMixin:
|
||||
output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
|
||||
self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
|
||||
self._check_similar_generate_outputs(output_greedy, output_prompt_lookup)
|
||||
for output in (output_greedy, output_prompt_lookup):
|
||||
self._check_outputs(output, main_input, model.config, use_cache=True)
|
||||
self._check_outputs(output, model.config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_dola_decoding_sample(self):
|
||||
@@ -1231,7 +1254,6 @@ class GenerationTesterMixin:
|
||||
|
||||
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
# Encoder-decoder models are not supported
|
||||
if config.is_encoder_decoder:
|
||||
@@ -1259,7 +1281,7 @@ class GenerationTesterMixin:
|
||||
"dola_layers": "low",
|
||||
}
|
||||
output_dola = model.generate(**generation_kwargs, **inputs_dict)
|
||||
self._check_outputs(output_dola, main_input, model.config, use_cache=getattr(config, "use_cache", False))
|
||||
self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_assisted_decoding_sample(self):
|
||||
@@ -1289,7 +1311,6 @@ class GenerationTesterMixin:
|
||||
|
||||
# enable cache
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
main_input = inputs_dict[model_class.main_input_name]
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
@@ -1321,7 +1342,7 @@ class GenerationTesterMixin:
|
||||
}
|
||||
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
self._check_outputs(output_assisted, main_input, config, use_cache=True)
|
||||
self._check_outputs(output_assisted, config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_prompt_lookup_decoding_stops_at_eos(self):
|
||||
@@ -1547,75 +1568,93 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([(1,), (2,)])
|
||||
def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
def test_generate_from_inputs_embeds(self, _, num_beams):
|
||||
"""Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc"""
|
||||
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
|
||||
# if fails, you should probably update the `prepare_inputs_for_generation` function
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
# Ignore:
|
||||
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
|
||||
# which would cause a mismatch),
|
||||
config.pad_token_id = config.eos_token_id = -1
|
||||
# b) embedding scaling, the scaling factor applied after embeding from input_ids (requires knowledge of the
|
||||
# variable that holds the scaling factor, which is model-dependent)
|
||||
if hasattr(config, "scale_embedding"):
|
||||
config.scale_embedding = False
|
||||
|
||||
# This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
|
||||
# decoder)
|
||||
if config.is_encoder_decoder:
|
||||
continue
|
||||
config.is_decoder = True
|
||||
|
||||
# Skip models without explicit support
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
|
||||
continue
|
||||
|
||||
# There are a few exception patterns in this test:
|
||||
# 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed
|
||||
requires_inputs_ids = any(
|
||||
model_name in model_class.__name__.lower() for model_name in ["idefics", "qwen2vl"]
|
||||
)
|
||||
# 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
|
||||
# than calling the embedding layer with `input_ids`. Subcases of this exception:
|
||||
# 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
|
||||
if hasattr(config, "scale_embedding"):
|
||||
config.scale_embedding = False
|
||||
# 2.B - Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
|
||||
# exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
|
||||
# checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
|
||||
pixel_values_is_mutually_exclusive = any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma"]
|
||||
)
|
||||
if pixel_values_is_mutually_exclusive:
|
||||
inputs_dict.pop("pixel_values", None)
|
||||
inputs_dict.pop("pixel_values_videos", None)
|
||||
inputs_dict.pop("pixel_values_images", None)
|
||||
# 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
|
||||
has_complex_embeds_computation = any(
|
||||
model_name in model_class.__name__.lower() for model_name in ["moshi"]
|
||||
)
|
||||
# 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
|
||||
# we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input.
|
||||
missing_attention_mask = "attention_mask" not in inputs_dict
|
||||
|
||||
# Traditional way of generating text
|
||||
input_ids = inputs_dict.pop("input_ids")
|
||||
generation_kwargs = {
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
"num_beams": num_beams,
|
||||
"do_sample": False,
|
||||
"max_new_tokens": 5,
|
||||
"min_new_tokens": 5, # generate exactly 5 tokens
|
||||
}
|
||||
|
||||
# Traditional way of generating text
|
||||
outputs_from_ids = model.generate(input_ids, max_new_tokens=5, **generation_kwargs)
|
||||
outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict)
|
||||
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
|
||||
|
||||
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
|
||||
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
|
||||
# The output of the two calls should be the same.
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
outputs_from_embeds = model.generate(
|
||||
input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
max_new_tokens=5,
|
||||
**generation_kwargs,
|
||||
input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
|
||||
if not has_complex_embeds_computation:
|
||||
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
|
||||
|
||||
# But if we pass different inputs_embeds, we should get different outputs (the output text may be the
|
||||
# If we pass different inputs_embeds, we should get different outputs (the output text may be the
|
||||
# same, but the logits will almost surely be different)
|
||||
random_embeds = torch.rand_like(inputs_embeds)
|
||||
outputs_from_rand_embeds = model.generate(
|
||||
input_ids,
|
||||
inputs_embeds=random_embeds,
|
||||
max_new_tokens=5,
|
||||
**generation_kwargs,
|
||||
input_ids, inputs_embeds=random_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
for i in range(len(outputs_from_rand_embeds.scores)):
|
||||
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
|
||||
|
||||
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
|
||||
outputs_from_embeds_wo_ids = model.generate(
|
||||
inputs_embeds=inputs_embeds, max_new_tokens=5, **generation_kwargs
|
||||
)
|
||||
self.assertListEqual(
|
||||
outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(),
|
||||
outputs_from_embeds_wo_ids.sequences.tolist(),
|
||||
)
|
||||
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
|
||||
# be the same
|
||||
if not (requires_inputs_ids or missing_attention_mask):
|
||||
outputs_from_embeds_wo_ids = model.generate(
|
||||
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
|
||||
self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
@@ -1829,10 +1868,8 @@ class GenerationTesterMixin:
|
||||
@pytest.mark.generate
|
||||
def test_generate_with_static_cache(self):
|
||||
"""
|
||||
Tests if StaticCache works if we set attn_implementation=static when generation.
|
||||
This doesn't test if generation quality is good, but tests that models with
|
||||
self._supports_static_cache don't throw an error when generating and return
|
||||
a StaticCache object at the end.
|
||||
Tests that generating with static cache give almost same results as with dynamic cache, and the output cache
|
||||
has the expected shapes
|
||||
"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_static_cache:
|
||||
@@ -1851,13 +1888,15 @@ class GenerationTesterMixin:
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
"max_length": None,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"cache_implementation": "static",
|
||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||
"output_scores": True,
|
||||
"use_cache": True,
|
||||
}
|
||||
|
||||
static_cache_generation = model.generate(**generation_kwargs, **inputs_dict, cache_implementation="static")
|
||||
|
||||
# Check 1: The cache shapes must match the expected shapes
|
||||
max_cache_len = seq_length + max_new_tokens
|
||||
config = config.text_config if hasattr(config, "text_config") else config
|
||||
head_dim = (
|
||||
@@ -1869,12 +1908,14 @@ class GenerationTesterMixin:
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
results = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
|
||||
self.assertTrue(isinstance(results.past_key_values, StaticCache))
|
||||
self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
|
||||
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
|
||||
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
|
||||
|
||||
# Check 2: The outputs must be similar to the case with dynamic cache
|
||||
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
|
||||
self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
|
||||
|
||||
@require_optimum_quanto
|
||||
@pytest.mark.generate
|
||||
@@ -1908,25 +1949,32 @@ class GenerationTesterMixin:
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("forward_only", False), # TODO (@joao): a few models failing. After fixed, this should not be "@slow"
|
||||
("end_to_end", True), # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix
|
||||
]
|
||||
)
|
||||
@pytest.mark.generate
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
@is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
|
||||
def test_generate_compile_fullgraph(self):
|
||||
def test_generate_compile(self, _, end_to_end):
|
||||
"""
|
||||
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
|
||||
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
|
||||
end-to-end compilation and forward pass compilation only.
|
||||
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
|
||||
"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_static_cache:
|
||||
self.skipTest("This model doesn't support static cache")
|
||||
|
||||
# TODO (joao) -- fix and enable me :)
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
|
||||
if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
|
||||
self.skipTest("whisper model end-to-end generate compile not yet supported")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# TODO (joao) -- fix and enable me :)
|
||||
if config.is_encoder_decoder:
|
||||
if end_to_end and config.is_encoder_decoder:
|
||||
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
@@ -1941,27 +1989,33 @@ class GenerationTesterMixin:
|
||||
generation_kwargs = {
|
||||
"do_sample": False,
|
||||
"max_new_tokens": 10,
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
}
|
||||
# end-to-end works best with dynamic cache, forward compilation works best with static cache
|
||||
if not end_to_end:
|
||||
generation_kwargs["cache_implementation"] = "static"
|
||||
|
||||
max_cache_len = input_ids.shape[1] + generation_kwargs["max_new_tokens"]
|
||||
config = config.get_text_config()
|
||||
past_key_values = StaticCache(
|
||||
config, batch_size=half_batch_size, max_cache_len=max_cache_len, device=torch_device
|
||||
)
|
||||
|
||||
# get eager + dynamic cache results for future comparison
|
||||
dynamic_outputs = []
|
||||
for model_inputs in input_ids_sets:
|
||||
# eager dynamic cache
|
||||
output_dynamic = model.generate(model_inputs, **generation_kwargs)
|
||||
dynamic_outputs.append(model.generate(model_inputs, **generation_kwargs))
|
||||
|
||||
# end-to-end compiled dynamic cache
|
||||
torch.compiler.reset()
|
||||
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
generation_config.update(**generation_kwargs)
|
||||
output_compiled = compiled_generate(
|
||||
model_inputs, generation_config=generation_config, past_key_values=past_key_values
|
||||
)
|
||||
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
|
||||
# get compiled results
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
generation_config.update(**generation_kwargs)
|
||||
torch.compiler.reset()
|
||||
if end_to_end:
|
||||
model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
|
||||
else:
|
||||
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
|
||||
|
||||
compiled_outputs = []
|
||||
for model_inputs in input_ids_sets:
|
||||
compiled_outputs.append(model.generate(model_inputs, generation_config=generation_config))
|
||||
|
||||
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
|
||||
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_methods_with_num_logits_to_keep(self):
|
||||
@@ -1989,7 +2043,6 @@ class GenerationTesterMixin:
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
@is_flaky() # assisted generation tests are flaky (minor fp ops differences)
|
||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
@@ -1998,6 +2051,9 @@ class GenerationTesterMixin:
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
@@ -2010,14 +2066,16 @@ class GenerationTesterMixin:
|
||||
"max_new_tokens": 10,
|
||||
"do_sample": False,
|
||||
"assistant_model": assistant_model,
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
}
|
||||
|
||||
assistant_model.generation_config.assistant_confidence_threshold = None
|
||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0)
|
||||
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
self._check_similar_generate_outputs(with_all_logits, without_all_logits)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_inherits_generation_mixin(self):
|
||||
@@ -2028,14 +2086,21 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
self.assertTrue("GenerationMixin" in str(model_class.__bases__))
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
def _test_attention_implementation(self, attn_implementation):
|
||||
"""
|
||||
Compares the output of generate with the eager attention implementation against other implementations.
|
||||
NOTE: despite the test logic being the same, different implementations actually need diferent decorators, hence
|
||||
this separate function.
|
||||
"""
|
||||
max_new_tokens = 30
|
||||
support_flag = {
|
||||
"sdpa": "_supports_sdpa",
|
||||
"flash_attention_2": "_supports_flash_attn_2",
|
||||
}
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_sdpa:
|
||||
self.skipTest(f"{model_class.__name__} does not support SDPA")
|
||||
if not getattr(model_class, support_flag[attn_implementation]):
|
||||
self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`")
|
||||
|
||||
config, original_inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
inputs_dict = {}
|
||||
@@ -2062,17 +2127,9 @@ class GenerationTesterMixin:
|
||||
"do_sample": False,
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
"use_cache": True,
|
||||
}
|
||||
|
||||
model_sdpa = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
).to(torch_device)
|
||||
res_sdpa = model_sdpa.generate(**inputs_dict, **generate_kwargs)
|
||||
del model_sdpa
|
||||
gc.collect()
|
||||
|
||||
model_eager = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
@@ -2083,42 +2140,46 @@ class GenerationTesterMixin:
|
||||
del model_eager
|
||||
gc.collect()
|
||||
|
||||
# Eager and SDPA are very similar, but not exactly the same. Because we are using random models, this
|
||||
# test would be flaky if we only checked the sequences. Two situations in which this test passes:
|
||||
# 1. The sequences are the same
|
||||
# 2. The sequences are different, but the scores up until the first mismatch are nearly identical
|
||||
output_matches = res_eager.sequences == res_sdpa.sequences
|
||||
has_matching_outputs = output_matches.all()
|
||||
has_matching_scores = None
|
||||
if not has_matching_outputs:
|
||||
input_length = main_input.shape[1]
|
||||
for batch_idx in range(res_eager.sequences.shape[0]):
|
||||
batch_matches = output_matches[batch_idx]
|
||||
if batch_matches.all():
|
||||
continue
|
||||
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
|
||||
first_mismatch_idx -= input_length # scores doesn't include data regarding input tokens
|
||||
sdpa_first_mismatch_scores = res_sdpa.scores[first_mismatch_idx][batch_idx]
|
||||
eager_first_mismatch_scores = res_eager.scores[first_mismatch_idx][batch_idx]
|
||||
has_matching_scores = torch.allclose(
|
||||
sdpa_first_mismatch_scores, eager_first_mismatch_scores, rtol=1e-3, atol=1e-3
|
||||
)
|
||||
if not has_matching_scores:
|
||||
break
|
||||
model_attn = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
attn_implementation=attn_implementation,
|
||||
).to(torch_device)
|
||||
res_attn = model_attn.generate(**inputs_dict, **generate_kwargs)
|
||||
del model_attn
|
||||
gc.collect()
|
||||
|
||||
self.assertTrue(has_matching_outputs or has_matching_scores)
|
||||
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
|
||||
# we can be sure what is batch size from main input but seq length depends on model type and whether input is text/audio/image
|
||||
# so we infer actual text seq length from model_tester, same was as it is done in `test_modeling_common.py` tests`
|
||||
batch_size = main_input.shape[0]
|
||||
@pytest.mark.generate
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
"""Tests that generate has equivalent outputs with SDPA and eager attention implementations."""
|
||||
self._test_attention_implementation("sdpa")
|
||||
|
||||
@pytest.mark.flash_attn_test
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_eager_matches_fa2_generate(self):
|
||||
"""Tests that generate has equivalent outputs with FA2 and eager attention implementations."""
|
||||
# TODO (@joao @raushan) -- this test is failing the output checks on most models, investigate. After fixing,
|
||||
# check whether we still need the overwrites
|
||||
self._test_attention_implementation("flash_attention_2")
|
||||
|
||||
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
|
||||
internal_batch_size = (
|
||||
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
|
||||
)
|
||||
|
||||
seq_length = getattr(self.model_tester, "seq_length", None)
|
||||
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
|
||||
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
|
||||
|
||||
config = config.text_config if hasattr(config, "text_config") else config
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
|
||||
gen_len = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
||||
@@ -2129,19 +2190,21 @@ class GenerationTesterMixin:
|
||||
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
||||
|
||||
# scores
|
||||
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
|
||||
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
|
||||
|
||||
# unprocessed logits
|
||||
self._check_logits(num_sequences_in_output, output.logits, config=config)
|
||||
self._check_logits(internal_batch_size, output.logits, config=config)
|
||||
|
||||
# Attentions
|
||||
if self.has_attentions:
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
|
||||
self._check_encoder_attention_for_generate(
|
||||
output.encoder_attentions, input_batch_size, config, seq_length
|
||||
)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
internal_batch_size,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
@@ -2153,7 +2216,7 @@ class GenerationTesterMixin:
|
||||
attentions = output.attentions if not use_cache else output.attentions[1:]
|
||||
min_length = seq_length if not use_cache else seq_length + 1
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
internal_batch_size,
|
||||
attentions=attentions,
|
||||
min_length=min_length,
|
||||
max_length=output.sequences.shape[-1],
|
||||
@@ -2165,12 +2228,12 @@ class GenerationTesterMixin:
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
self._check_encoder_hidden_states_for_generate(
|
||||
output.encoder_hidden_states, batch_size, config, seq_length
|
||||
output.encoder_hidden_states, input_batch_size, config, seq_length
|
||||
)
|
||||
|
||||
# decoder
|
||||
self._check_hidden_states_for_generate(
|
||||
num_sequences_in_output,
|
||||
internal_batch_size,
|
||||
output.decoder_hidden_states,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
@@ -2182,7 +2245,7 @@ class GenerationTesterMixin:
|
||||
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
|
||||
min_length = seq_length if not use_cache else seq_length + 1
|
||||
self._check_hidden_states_for_generate(
|
||||
num_sequences_in_output,
|
||||
internal_batch_size,
|
||||
hidden_states,
|
||||
min_length=min_length,
|
||||
max_length=output.sequences.shape[-1],
|
||||
@@ -2213,7 +2276,7 @@ class GenerationTesterMixin:
|
||||
past_key_values = output.past_key_values
|
||||
past_sequence_length = output.sequences.shape[-1] - 1
|
||||
self._check_past_key_values_for_generate(
|
||||
num_sequences_in_output,
|
||||
internal_batch_size,
|
||||
past_key_values,
|
||||
seq_length=past_sequence_length,
|
||||
config=config,
|
||||
|
||||
Reference in New Issue
Block a user