Tests: move generate tests to the right mixin and delete redundant tests (#34464)

* tmp commit

* tmp commit

* cull overwrites of deleted tests

* typo

* more specific docstring

* make fixup

* parameterize at the top?

* correction

* more deletions :D

* tmp commit

* for VLMs too

* fix _check_outputs

* test nit

* make fixup

* fix another flaky

* test_generate_from_inputs_embeds -- handle missing attention mask
This commit is contained in:
Joao Gante
2024-10-30 10:59:08 +00:00
committed by GitHub
parent 913330ca9f
commit 8a734ea2c3
46 changed files with 265 additions and 2348 deletions

View File

@@ -29,6 +29,7 @@ from transformers import AutoConfig, is_torch_available, pipeline, set_seed
from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_flash_attn,
require_optimum_quanto,
require_torch,
require_torch_gpu,
@@ -136,6 +137,34 @@ class GenerationTesterMixin:
return config, filtered_inputs_dict
def _check_similar_generate_outputs(self, output_1, output_2, atol=1e-5, rtol=1e-5):
    """
    Checks whether a pair of generate outputs are similar. Two `generate` call outputs are considered similar in
    the following situations:
    1. The sequences are the same
    2. The sequences are different, but the scores up to (and including) the first mismatch are nearly identical

    Args:
        output_1: First `generate` output. Must expose `.sequences` (compared element-wise, indexed as
            `[batch, position]`) and `.scores` (indexed as `scores[step][batch_idx]`).
        output_2: Second `generate` output, with the same attributes as `output_1`.
        atol: Absolute tolerance forwarded to `torch.allclose` when comparing scores.
        rtol: Relative tolerance forwarded to `torch.allclose` when comparing scores.

    The result is reported through `self.assertTrue`, i.e. the test fails when neither situation holds.
    """
    # scores doesn't include data regarding decoder input tokens
    decoder_input_length = output_1.sequences.shape[1] - len(output_1.scores)
    output_matches = output_1.sequences == output_2.sequences
    has_matching_outputs = output_matches.all()
    has_matching_scores = None
    if not has_matching_outputs:
        for batch_idx in range(output_1.sequences.shape[0]):
            batch_matches = output_matches[batch_idx]
            if batch_matches.all():
                continue
            first_mismatch_idx = batch_matches.int().argmin()  # gets the index of the first False
            # shift from a position in `sequences` to the corresponding generation step in `scores`
            first_mismatch_idx -= decoder_input_length
            output_1_first_mismatch_scores = output_1.scores[first_mismatch_idx][batch_idx]
            output_2_first_mismatch_scores = output_2.scores[first_mismatch_idx][batch_idx]
            # NOTE: each tolerance is forwarded to its matching `torch.allclose` argument (they used to be swapped)
            has_matching_scores = torch.allclose(
                output_1_first_mismatch_scores, output_2_first_mismatch_scores, rtol=rtol, atol=atol
            )
            # a single row whose scores diverge is enough to declare the outputs dissimilar
            if not has_matching_scores:
                break
    self.assertTrue(has_matching_outputs or has_matching_scores)
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
logits_processor_kwargs = {
"bad_words_ids": [[1, 0]],
@@ -426,7 +455,6 @@ class GenerationTesterMixin:
def test_greedy_generate_dict_outputs(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate(
@@ -453,13 +481,12 @@ class GenerationTesterMixin:
# Retrocompatibility check
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
self._check_outputs(output_generate, main_input, model.config)
self._check_outputs(output_generate, model.config)
@pytest.mark.generate
def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -486,7 +513,7 @@ class GenerationTesterMixin:
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
)
self._check_outputs(output_generate, main_input, model.config, use_cache=True)
self._check_outputs(output_generate, model.config, use_cache=True)
@pytest.mark.generate
def test_sample_generate(self):
@@ -505,7 +532,6 @@ class GenerationTesterMixin:
def test_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate(
@@ -533,7 +559,7 @@ class GenerationTesterMixin:
# Retrocompatibility check
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
self._check_outputs(output_generate, main_input, model.config, num_return_sequences=2)
self._check_outputs(output_generate, model.config, num_return_sequences=2)
@pytest.mark.generate
def test_beam_search_generate(self):
@@ -554,7 +580,6 @@ class GenerationTesterMixin:
def test_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
@@ -583,14 +608,16 @@ class GenerationTesterMixin:
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
output_generate,
model.config,
num_return_sequences=beam_kwargs["num_return_sequences"],
num_beams=beam_kwargs["num_beams"],
)
@pytest.mark.generate
def test_beam_search_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -623,10 +650,10 @@ class GenerationTesterMixin:
self._check_outputs(
output_generate,
main_input,
model.config,
use_cache=True,
num_return_sequences=beam_kwargs["num_beams"],
num_return_sequences=beam_kwargs["num_return_sequences"],
num_beams=beam_kwargs["num_beams"],
)
@require_accelerate
@@ -675,7 +702,6 @@ class GenerationTesterMixin:
def test_beam_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
@@ -706,7 +732,10 @@ class GenerationTesterMixin:
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
self._check_outputs(
output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
output_generate,
model.config,
num_return_sequences=beam_kwargs["num_return_sequences"],
num_beams=beam_kwargs["num_beams"],
)
@pytest.mark.generate
@@ -765,7 +794,6 @@ class GenerationTesterMixin:
def test_group_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_diverse_beam_kwargs()
@@ -794,7 +822,10 @@ class GenerationTesterMixin:
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
output_generate,
model.config,
num_return_sequences=beam_kwargs["num_return_sequences"],
num_beams=beam_kwargs["num_beams"],
)
# TODO: @gante check why it is flaky
@@ -859,7 +890,6 @@ class GenerationTesterMixin:
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
model = model_class(config).to(torch_device).eval()
@@ -899,7 +929,10 @@ class GenerationTesterMixin:
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, main_input, model.config, num_return_sequences=beam_kwargs["num_beams"]
output_generate,
model.config,
num_return_sequences=beam_kwargs["num_return_sequences"],
num_beams=beam_kwargs["num_beams"],
)
@pytest.mark.generate
@@ -942,7 +975,6 @@ class GenerationTesterMixin:
self.skipTest(reason="Won't fix: old model with different cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@@ -968,7 +1000,7 @@ class GenerationTesterMixin:
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
)
self._check_outputs(output_generate, main_input, model.config, use_cache=True)
self._check_outputs(output_generate, model.config, use_cache=True)
@pytest.mark.generate
def test_contrastive_generate_low_memory(self):
@@ -1064,14 +1096,10 @@ class GenerationTesterMixin:
@pytest.mark.generate
@parameterized.expand([("random",), ("same",)])
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
# NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
# shape differences -- and it may result in a different output. The input shape difference happens in the
# main model, that runs the forward pass with several candidates at once (as opposed to generating one token at
# a time). See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
# NOTE (2): It breaks the pattern in the tests above, for multiple reasons:
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
# NOTE: It breaks the pattern in the tests above, for multiple reasons:
# - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
# prepare the assistant encoder outputs in the main generate body);
# - assisted_decoding does not support `use_cache = False`
@@ -1100,7 +1128,6 @@ class GenerationTesterMixin:
# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
main_input = inputs_dict[model_class.main_input_name]
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@@ -1141,12 +1168,10 @@ class GenerationTesterMixin:
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
# The two outputs must match and their shape must be as expected
self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist())
self._check_similar_generate_outputs(output_greedy, output_assisted)
for output in (output_greedy, output_assisted):
self._check_outputs(output, main_input, model.config, use_cache=True)
self._check_outputs(output, model.config, use_cache=True)
@is_flaky()
@pytest.mark.generate
def test_prompt_lookup_decoding_matches_greedy_search(self):
# This test ensures that the prompt lookup generation does not introduce output changes over greedy search.
@@ -1175,7 +1200,6 @@ class GenerationTesterMixin:
# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
main_input = inputs_dict[model_class.main_input_name]
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@@ -1208,10 +1232,9 @@ class GenerationTesterMixin:
output_prompt_lookup = model.generate(**generation_kwargs, **inputs_dict)
# The two outputs must match and their shape must be as expected
self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist())
self._check_similar_generate_outputs(output_greedy, output_prompt_lookup)
for output in (output_greedy, output_prompt_lookup):
self._check_outputs(output, main_input, model.config, use_cache=True)
self._check_outputs(output, model.config, use_cache=True)
@pytest.mark.generate
def test_dola_decoding_sample(self):
@@ -1231,7 +1254,6 @@ class GenerationTesterMixin:
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
main_input = inputs_dict[model_class.main_input_name]
# Encoder-decoder models are not supported
if config.is_encoder_decoder:
@@ -1259,7 +1281,7 @@ class GenerationTesterMixin:
"dola_layers": "low",
}
output_dola = model.generate(**generation_kwargs, **inputs_dict)
self._check_outputs(output_dola, main_input, model.config, use_cache=getattr(config, "use_cache", False))
self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
@pytest.mark.generate
def test_assisted_decoding_sample(self):
@@ -1289,7 +1311,6 @@ class GenerationTesterMixin:
# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
main_input = inputs_dict[model_class.main_input_name]
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
@@ -1321,7 +1342,7 @@ class GenerationTesterMixin:
}
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
self._check_outputs(output_assisted, main_input, config, use_cache=True)
self._check_outputs(output_assisted, config, use_cache=True)
@pytest.mark.generate
def test_prompt_lookup_decoding_stops_at_eos(self):
@@ -1547,75 +1568,93 @@ class GenerationTesterMixin:
)
@pytest.mark.generate
@parameterized.expand([(1,), (2,)])
def test_generate_from_inputs_embeds_decoder_only(self, num_beams):
@parameterized.expand([("greedy", 1), ("beam search", 2)])
def test_generate_from_inputs_embeds(self, _, num_beams):
"""Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc"""
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
# if fails, you should probably update the `prepare_inputs_for_generation` function
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# Ignore:
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
# which would cause a mismatch),
config.pad_token_id = config.eos_token_id = -1
# b) embedding scaling, the scaling factor applied after embeding from input_ids (requires knowledge of the
# variable that holds the scaling factor, which is model-dependent)
if hasattr(config, "scale_embedding"):
config.scale_embedding = False
# This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
# decoder)
if config.is_encoder_decoder:
continue
config.is_decoder = True
# Skip models without explicit support
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
continue
# There are a few exception patterns in this test:
# 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed
requires_inputs_ids = any(
model_name in model_class.__name__.lower() for model_name in ["idefics", "qwen2vl"]
)
# 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
# than calling the embedding layer with `input_ids`. Subcases of this exception:
# 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
if hasattr(config, "scale_embedding"):
config.scale_embedding = False
# 2.B - Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
# exception above (complex `inputs_embeds` computation). Popping `pixel_values` allows us to run the
# checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
pixel_values_is_mutually_exclusive = any(
model_name in model_class.__name__.lower()
for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma"]
)
if pixel_values_is_mutually_exclusive:
inputs_dict.pop("pixel_values", None)
inputs_dict.pop("pixel_values_videos", None)
inputs_dict.pop("pixel_values_images", None)
# 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
has_complex_embeds_computation = any(
model_name in model_class.__name__.lower() for model_name in ["moshi"]
)
# 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
# we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input.
missing_attention_mask = "attention_mask" not in inputs_dict
# Traditional way of generating text
input_ids = inputs_dict.pop("input_ids")
generation_kwargs = {
"return_dict_in_generate": True,
"output_scores": True,
"num_beams": num_beams,
"do_sample": False,
"max_new_tokens": 5,
"min_new_tokens": 5, # generate exactly 5 tokens
}
# Traditional way of generating text
outputs_from_ids = model.generate(input_ids, max_new_tokens=5, **generation_kwargs)
outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
# The output of the two calls should be the same.
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(
input_ids,
inputs_embeds=inputs_embeds,
max_new_tokens=5,
**generation_kwargs,
input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
)
self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
if not has_complex_embeds_computation:
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
# But if we pass different inputs_embeds, we should get different outputs (the output text may be the
# If we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids,
inputs_embeds=random_embeds,
max_new_tokens=5,
**generation_kwargs,
input_ids, inputs_embeds=random_embeds, **generation_kwargs, **inputs_dict
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
# input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
outputs_from_embeds_wo_ids = model.generate(
inputs_embeds=inputs_embeds, max_new_tokens=5, **generation_kwargs
)
self.assertListEqual(
outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(),
outputs_from_embeds_wo_ids.sequences.tolist(),
)
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
# be the same
if not (requires_inputs_ids or missing_attention_mask):
outputs_from_embeds_wo_ids = model.generate(
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
)
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)
@pytest.mark.generate
def test_generate_from_inputs_embeds_with_static_cache(self):
@@ -1829,10 +1868,8 @@ class GenerationTesterMixin:
@pytest.mark.generate
def test_generate_with_static_cache(self):
"""
Tests if StaticCache works if we set attn_implementation=static when generation.
This doesn't test if generation quality is good, but tests that models with
self._supports_static_cache don't throw an error when generating and return
a StaticCache object at the end.
Tests that generating with static cache gives almost the same results as with dynamic cache, and the output cache
has the expected shapes
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
@@ -1851,13 +1888,15 @@ class GenerationTesterMixin:
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_length": None,
"max_new_tokens": max_new_tokens,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
"output_scores": True,
"use_cache": True,
}
static_cache_generation = model.generate(**generation_kwargs, **inputs_dict, cache_implementation="static")
# Check 1: The cache shapes must match the expected shapes
max_cache_len = seq_length + max_new_tokens
config = config.text_config if hasattr(config, "text_config") else config
head_dim = (
@@ -1869,12 +1908,14 @@ class GenerationTesterMixin:
else config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers
results = model.generate(**generation_kwargs, **inputs_dict)
cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
self.assertTrue(isinstance(results.past_key_values, StaticCache))
self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
self.assertTrue(isinstance(static_cache_generation.past_key_values, StaticCache))
self.assertTrue(len(static_cache_generation.past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(static_cache_generation.past_key_values.key_cache[0].shape == cache_shape)
# Check 2: The outputs must be similar to the case with dynamic cache
dynamic_cache_generation = model.generate(**generation_kwargs, **inputs_dict)
self._check_similar_generate_outputs(dynamic_cache_generation, static_cache_generation)
@require_optimum_quanto
@pytest.mark.generate
@@ -1908,25 +1949,32 @@ class GenerationTesterMixin:
with self.assertRaises(ValueError):
model.generate(**generation_kwargs, **inputs_dict)
@parameterized.expand(
[
("forward_only", False), # TODO (@joao): a few models failing. After fixed, this should not be "@slow"
("end_to_end", True), # TODO (@joao): end-to-end compilation is broken with torch 2.5+, explore and fix
]
)
@pytest.mark.generate
@require_torch_gpu
@slow
@is_flaky() # compilation may result in equivalent (!= same) FP ops, causing the argmax in `generate` to be flaky
def test_generate_compile_fullgraph(self):
def test_generate_compile(self, _, end_to_end):
"""
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests
end-to-end compilation and forward pass compilation only.
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache")
# TODO (joao) -- fix and enable me :)
if any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
if end_to_end and any(model_name in model_class.__name__.lower() for model_name in ["whisper"]):
self.skipTest("whisper model end-to-end generate compile not yet supported")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# TODO (joao) -- fix and enable me :)
if config.is_encoder_decoder:
if end_to_end and config.is_encoder_decoder:
self.skipTest("Encoder-decoder model end-to-end generate compile not yet supported")
model = model_class(config).to(torch_device)
@@ -1941,27 +1989,33 @@ class GenerationTesterMixin:
generation_kwargs = {
"do_sample": False,
"max_new_tokens": 10,
"return_dict_in_generate": True,
"output_scores": True,
}
# end-to-end works best with dynamic cache, forward compilation works best with static cache
if not end_to_end:
generation_kwargs["cache_implementation"] = "static"
max_cache_len = input_ids.shape[1] + generation_kwargs["max_new_tokens"]
config = config.get_text_config()
past_key_values = StaticCache(
config, batch_size=half_batch_size, max_cache_len=max_cache_len, device=torch_device
)
# get eager + dynamic cache results for future comparison
dynamic_outputs = []
for model_inputs in input_ids_sets:
# eager dynamic cache
output_dynamic = model.generate(model_inputs, **generation_kwargs)
dynamic_outputs.append(model.generate(model_inputs, **generation_kwargs))
# end-to-end compiled dynamic cache
torch.compiler.reset()
compiled_generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
output_compiled = compiled_generate(
model_inputs, generation_config=generation_config, past_key_values=past_key_values
)
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
# get compiled results
generation_config = copy.deepcopy(model.generation_config)
generation_config.update(**generation_kwargs)
torch.compiler.reset()
if end_to_end:
model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
else:
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
compiled_outputs = []
for model_inputs in input_ids_sets:
compiled_outputs.append(model.generate(model_inputs, generation_config=generation_config))
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
self._check_similar_generate_outputs(dynamic_result, compiled_result)
@pytest.mark.generate
def test_generate_methods_with_num_logits_to_keep(self):
@@ -1989,7 +2043,6 @@ class GenerationTesterMixin:
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
@pytest.mark.generate
@is_flaky() # assisted generation tests are flaky (minor fp ops differences)
def test_assisted_decoding_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
@@ -1998,6 +2051,9 @@ class GenerationTesterMixin:
self.skipTest(reason="Stateful models don't support assisted generation")
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.use_cache = True
config.is_decoder = True
@@ -2010,14 +2066,16 @@ class GenerationTesterMixin:
"max_new_tokens": 10,
"do_sample": False,
"assistant_model": assistant_model,
"return_dict_in_generate": True,
"output_scores": True,
}
assistant_model.generation_config.assistant_confidence_threshold = None
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0)
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
self._check_similar_generate_outputs(with_all_logits, without_all_logits)
@pytest.mark.generate
def test_inherits_generation_mixin(self):
@@ -2028,14 +2086,21 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
self.assertTrue("GenerationMixin" in str(model_class.__bases__))
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
def _test_attention_implementation(self, attn_implementation):
"""
Compares the output of generate with the eager attention implementation against other implementations.
NOTE: despite the test logic being the same, different implementations actually need different decorators, hence
this separate function.
"""
max_new_tokens = 30
support_flag = {
"sdpa": "_supports_sdpa",
"flash_attention_2": "_supports_flash_attn_2",
}
for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
if not getattr(model_class, support_flag[attn_implementation]):
self.skipTest(f"{model_class.__name__} does not support `attn_implementation={attn_implementation}`")
config, original_inputs_dict = self.prepare_config_and_inputs_for_generate()
inputs_dict = {}
@@ -2062,17 +2127,9 @@ class GenerationTesterMixin:
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
"use_cache": True,
}
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(torch_device)
res_sdpa = model_sdpa.generate(**inputs_dict, **generate_kwargs)
del model_sdpa
gc.collect()
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
@@ -2083,42 +2140,46 @@ class GenerationTesterMixin:
del model_eager
gc.collect()
# Eager and SDPA are very similar, but not exactly the same. Because we are using random models, this
# test would be flaky if we only checked the sequences. Two situations in which this test passes:
# 1. The sequences are the same
# 2. The sequences are different, but the scores up until the first mismatch are nearly identical
output_matches = res_eager.sequences == res_sdpa.sequences
has_matching_outputs = output_matches.all()
has_matching_scores = None
if not has_matching_outputs:
input_length = main_input.shape[1]
for batch_idx in range(res_eager.sequences.shape[0]):
batch_matches = output_matches[batch_idx]
if batch_matches.all():
continue
first_mismatch_idx = batch_matches.int().argmin() # gets the index of the first False
first_mismatch_idx -= input_length # scores doesn't include data regarding input tokens
sdpa_first_mismatch_scores = res_sdpa.scores[first_mismatch_idx][batch_idx]
eager_first_mismatch_scores = res_eager.scores[first_mismatch_idx][batch_idx]
has_matching_scores = torch.allclose(
sdpa_first_mismatch_scores, eager_first_mismatch_scores, rtol=1e-3, atol=1e-3
)
if not has_matching_scores:
break
model_attn = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation=attn_implementation,
).to(torch_device)
res_attn = model_attn.generate(**inputs_dict, **generate_kwargs)
del model_attn
gc.collect()
self.assertTrue(has_matching_outputs or has_matching_scores)
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
# we can be sure what is batch size from main input but seq length depends on model type and whether input is text/audio/image
# so we infer actual text seq length from model_tester, same was as it is done in `test_modeling_common.py` tests`
batch_size = main_input.shape[0]
@pytest.mark.generate
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
    """Checks that `generate` yields equivalent outputs with the SDPA and eager attention implementations."""
    # The comparison logic lives in the shared helper; this wrapper only carries the SDPA-specific decorators.
    self._test_attention_implementation(attn_implementation="sdpa")
@pytest.mark.flash_attn_test
@require_flash_attn
@require_torch_gpu
@slow
def test_eager_matches_fa2_generate(self):
    """Checks that `generate` yields equivalent outputs with the FA2 and eager attention implementations."""
    # TODO (@joao @raushan) -- this test is failing the output checks on most models, investigate. After fixing,
    # check whether we still need the overwrites
    # The comparison logic lives in the shared helper; this wrapper only carries the FA2-specific decorators.
    self._test_attention_implementation(attn_implementation="flash_attention_2")
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
internal_batch_size = (
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
)
seq_length = getattr(self.model_tester, "seq_length", None)
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
config = config.text_config if hasattr(config, "text_config") else config
num_sequences_in_output = batch_size * num_return_sequences
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
@@ -2129,19 +2190,21 @@ class GenerationTesterMixin:
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
# scores
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
# unprocessed logits
self._check_logits(num_sequences_in_output, output.logits, config=config)
self._check_logits(internal_batch_size, output.logits, config=config)
# Attentions
if self.has_attentions:
if config.is_encoder_decoder:
# encoder
self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
self._check_encoder_attention_for_generate(
output.encoder_attentions, input_batch_size, config, seq_length
)
# decoder
self._check_attentions_for_generate(
num_sequences_in_output,
internal_batch_size,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
@@ -2153,7 +2216,7 @@ class GenerationTesterMixin:
attentions = output.attentions if not use_cache else output.attentions[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_attentions_for_generate(
num_sequences_in_output,
internal_batch_size,
attentions=attentions,
min_length=min_length,
max_length=output.sequences.shape[-1],
@@ -2165,12 +2228,12 @@ class GenerationTesterMixin:
if config.is_encoder_decoder:
# encoder
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, batch_size, config, seq_length
output.encoder_hidden_states, input_batch_size, config, seq_length
)
# decoder
self._check_hidden_states_for_generate(
num_sequences_in_output,
internal_batch_size,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
@@ -2182,7 +2245,7 @@ class GenerationTesterMixin:
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_hidden_states_for_generate(
num_sequences_in_output,
internal_batch_size,
hidden_states,
min_length=min_length,
max_length=output.sequences.shape[-1],
@@ -2213,7 +2276,7 @@ class GenerationTesterMixin:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
num_sequences_in_output,
internal_batch_size,
past_key_values,
seq_length=past_sequence_length,
config=config,