[generate] return Cache object even if passed in a legacy format (#35673)

* generate returns a Cache object by default

* fix tests

* fix test for encoder-decoder models
This commit is contained in:
Joao Gante
2025-01-16 17:06:24 +00:00
committed by GitHub
parent 2818307e93
commit 94af1c0aa2
9 changed files with 36 additions and 156 deletions

View File

@@ -26,7 +26,7 @@ import numpy as np
import pytest
from parameterized import parameterized
from transformers import AutoConfig, is_torch_available, pipeline, set_seed
from transformers import AutoConfig, is_torch_available, pipeline
from transformers.testing_utils import (
is_flaky,
require_accelerate,
@@ -69,7 +69,7 @@ if is_torch_available():
SpeechEncoderDecoderModel,
T5ForConditionalGeneration,
)
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
@@ -1851,75 +1851,6 @@ class GenerationTesterMixin:
)
)
@parameterized.expand([(1, False), (1, True), (4, False)])
@pytest.mark.generate
def test_new_cache_format(self, num_beams, do_sample):
# Tests that generating with the new format is exactly the same as the legacy one (for models that support it).
# 👉 tests with and without beam search so that we can test with and without cache reordering.
# 👉 tests with and without sampling so we can cover the most common use cases.
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
self.skipTest(reason="This model does not support the new cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
"do_sample": do_sample,
"num_beams": num_beams,
"num_return_sequences": num_beams,
"return_dict_in_generate": True, # Required to return `past_key_values`
"use_cache": True,
}
# Sets seed before calling `generate` for the case with do_sample=True
seed = torch.randint(0, 1000000, (1,)).item()
set_seed(seed)
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
set_seed(seed)
if config.is_encoder_decoder:
cache_cls = EncoderDecoderCache
past_key_values = cache_cls(DynamicCache(), DynamicCache())
else:
cache_cls = DynamicCache
past_key_values = cache_cls()
new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict)
# The two sets of generated sequences must match, despite the cache format between forward passes being
# different
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
# The contents of the two caches, when converted to the same format (in both directions!), must match
legacy_cache = legacy_results.past_key_values
new_cache_converted = new_results.past_key_values.to_legacy_cache()
for layer_idx in range(len(legacy_cache)):
for kv_idx in range(len(legacy_cache[layer_idx])):
# TODO: @raushan, please look into this for new cache format
if legacy_cache[layer_idx][kv_idx] != []:
self.assertTrue(
torch.allclose(
legacy_cache[layer_idx][kv_idx],
new_cache_converted[layer_idx][kv_idx],
)
)
new_cache = new_results.past_key_values
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
for layer_idx in range(len(new_cache)):
for kv_idx in range(len(new_cache[layer_idx])):
# TODO: @raushan, please look into this for new cache format
if new_cache[layer_idx][kv_idx] != []:
self.assertTrue(
torch.allclose(
new_cache[layer_idx][kv_idx],
legacy_cache_converted[layer_idx][kv_idx],
)
)
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
@require_torch_gpu
@pytest.mark.generate
@@ -2438,11 +2369,11 @@ class GenerationTesterMixin:
)
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
self.assertIsInstance(past_key_values, tuple)
self.assertListEqual(
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
[True] * len(past_key_values),
)
self.assertIsInstance(past_key_values, (tuple, Cache))
# Encoder-decoder models: pull and verify the decoder cache
if isinstance(past_key_values, EncoderDecoderCache):
past_key_values = past_key_values.self_attention_cache
# (batch, head, seq_length, head_features)
expected_shape = (
@@ -2451,15 +2382,32 @@ class GenerationTesterMixin:
seq_length,
config.hidden_size // config.num_attention_heads,
)
# check shape key, value
self.assertListEqual(
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
[expected_shape] * len(past_key_values),
)
self.assertListEqual(
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
[expected_shape] * len(past_key_values),
)
if isinstance(past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in past_key_values.key_cache],
[expected_shape] * len(past_key_values.key_cache),
)
self.assertListEqual(
[value_tensor.shape for value_tensor in past_key_values.value_cache],
[expected_shape] * len(past_key_values.value_cache),
)
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
else:
self.assertListEqual(
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
[True] * len(past_key_values),
)
# check shape key, value
self.assertListEqual(
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
[expected_shape] * len(past_key_values),
)
self.assertListEqual(
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
[expected_shape] * len(past_key_values),
)
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.