[generate] return Cache object even if passed in a legacy format (#35673)
* generate returns a Cache object by default
* fix tests
* fix test for encoder-decoder models
This commit is contained in:
@@ -2111,9 +2111,6 @@ class GenerationMixin:
|
|||||||
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
|
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
|
||||||
# - different models have a different cache name expected by the model (default = "past_key_values")
|
# - different models have a different cache name expected by the model (default = "past_key_values")
|
||||||
# - `max_length`, prepared above, is used to determine the maximum cache length
|
# - `max_length`, prepared above, is used to determine the maximum cache length
|
||||||
# TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format)
|
|
||||||
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
|
|
||||||
user_defined_cache = model_kwargs.get(cache_name)
|
|
||||||
max_cache_length = generation_config.max_length
|
max_cache_length = generation_config.max_length
|
||||||
if (
|
if (
|
||||||
inputs_tensor.shape[1] != input_ids_length
|
inputs_tensor.shape[1] != input_ids_length
|
||||||
@@ -2395,32 +2392,12 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# Convert to legacy cache format if requested
|
# Convert to legacy cache format if requested
|
||||||
if (
|
if (
|
||||||
generation_config.return_legacy_cache is not False # Should check for `True` after v4.47
|
generation_config.return_legacy_cache is True
|
||||||
and not is_torchdynamo_compiling()
|
and not is_torchdynamo_compiling()
|
||||||
and hasattr(result, "past_key_values")
|
and hasattr(result, "past_key_values")
|
||||||
and hasattr(result.past_key_values, "to_legacy_cache")
|
and getattr(result.past_key_values, "to_legacy_cache") is not None
|
||||||
and result.past_key_values.to_legacy_cache is not None
|
|
||||||
):
|
):
|
||||||
# handle BC (convert by default if the user hasn't passed a cache AND the cache is of the default type)
|
result.past_key_values = result.past_key_values.to_legacy_cache()
|
||||||
should_convert_cache = generation_config.return_legacy_cache
|
|
||||||
is_user_defined_cache = user_defined_cache is not None
|
|
||||||
is_default_cache_type = (
|
|
||||||
type(result.past_key_values) == DynamicCache # noqa E721
|
|
||||||
or (
|
|
||||||
isinstance(result.past_key_values, EncoderDecoderCache)
|
|
||||||
and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721
|
|
||||||
and type(result.past_key_values.cross_attention_cache) == DynamicCache # noqa E721
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not is_user_defined_cache and is_default_cache_type:
|
|
||||||
logger.warning_once(
|
|
||||||
"From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` "
|
|
||||||
"instance instead by default (as opposed to the legacy tuple of tuples format). If you want to "
|
|
||||||
"keep returning the legacy format, please set `return_legacy_cache=True`."
|
|
||||||
)
|
|
||||||
should_convert_cache = True
|
|
||||||
if should_convert_cache:
|
|
||||||
result.past_key_values = result.past_key_values.to_legacy_cache()
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _has_unfinished_sequences(
|
def _has_unfinished_sequences(
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import AutoConfig, is_torch_available, pipeline, set_seed
|
from transformers import AutoConfig, is_torch_available, pipeline
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_flaky,
|
is_flaky,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
@@ -69,7 +69,7 @@ if is_torch_available():
|
|||||||
SpeechEncoderDecoderModel,
|
SpeechEncoderDecoderModel,
|
||||||
T5ForConditionalGeneration,
|
T5ForConditionalGeneration,
|
||||||
)
|
)
|
||||||
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
|
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
BeamSampleDecoderOnlyOutput,
|
BeamSampleDecoderOnlyOutput,
|
||||||
BeamSampleEncoderDecoderOutput,
|
BeamSampleEncoderDecoderOutput,
|
||||||
@@ -1851,75 +1851,6 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
@pytest.mark.generate
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
# Tests that generating with the new format is exactly the same as the legacy one (for models that support it).
|
|
||||||
# 👉 tests with and without beam search so that we can test with and without cache reordering.
|
|
||||||
# 👉 tests with and without sampling so we can cover the most common use cases.
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
if not model_class._supports_cache_class:
|
|
||||||
self.skipTest(reason="This model does not support the new cache format")
|
|
||||||
|
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
generation_kwargs = {
|
|
||||||
"max_new_tokens": 5,
|
|
||||||
"do_sample": do_sample,
|
|
||||||
"num_beams": num_beams,
|
|
||||||
"num_return_sequences": num_beams,
|
|
||||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
|
||||||
"use_cache": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Sets seed before calling `generate` for the case with do_sample=True
|
|
||||||
seed = torch.randint(0, 1000000, (1,)).item()
|
|
||||||
set_seed(seed)
|
|
||||||
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
|
|
||||||
set_seed(seed)
|
|
||||||
if config.is_encoder_decoder:
|
|
||||||
cache_cls = EncoderDecoderCache
|
|
||||||
past_key_values = cache_cls(DynamicCache(), DynamicCache())
|
|
||||||
else:
|
|
||||||
cache_cls = DynamicCache
|
|
||||||
past_key_values = cache_cls()
|
|
||||||
|
|
||||||
new_results = model.generate(past_key_values=past_key_values, **generation_kwargs, **inputs_dict)
|
|
||||||
|
|
||||||
# The two sets of generated sequences must match, despite the cache format between forward passes being
|
|
||||||
# different
|
|
||||||
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
|
|
||||||
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
|
|
||||||
self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
|
|
||||||
|
|
||||||
# The contents of the two caches, when converted to the same format (in both directions!), must match
|
|
||||||
legacy_cache = legacy_results.past_key_values
|
|
||||||
new_cache_converted = new_results.past_key_values.to_legacy_cache()
|
|
||||||
for layer_idx in range(len(legacy_cache)):
|
|
||||||
for kv_idx in range(len(legacy_cache[layer_idx])):
|
|
||||||
# TODO: @raushan, please look into this for new cache format
|
|
||||||
if legacy_cache[layer_idx][kv_idx] != []:
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(
|
|
||||||
legacy_cache[layer_idx][kv_idx],
|
|
||||||
new_cache_converted[layer_idx][kv_idx],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
new_cache = new_results.past_key_values
|
|
||||||
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
|
|
||||||
for layer_idx in range(len(new_cache)):
|
|
||||||
for kv_idx in range(len(new_cache[layer_idx])):
|
|
||||||
# TODO: @raushan, please look into this for new cache format
|
|
||||||
if new_cache[layer_idx][kv_idx] != []:
|
|
||||||
self.assertTrue(
|
|
||||||
torch.allclose(
|
|
||||||
new_cache[layer_idx][kv_idx],
|
|
||||||
legacy_cache_converted[layer_idx][kv_idx],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
|
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
@@ -2438,11 +2369,11 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
|
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
|
||||||
self.assertIsInstance(past_key_values, tuple)
|
self.assertIsInstance(past_key_values, (tuple, Cache))
|
||||||
self.assertListEqual(
|
|
||||||
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
|
# Encoder-decoder models: pull and verify the decoder cache
|
||||||
[True] * len(past_key_values),
|
if isinstance(past_key_values, EncoderDecoderCache):
|
||||||
)
|
past_key_values = past_key_values.self_attention_cache
|
||||||
|
|
||||||
# (batch, head, seq_length, head_features)
|
# (batch, head, seq_length, head_features)
|
||||||
expected_shape = (
|
expected_shape = (
|
||||||
@@ -2451,15 +2382,32 @@ class GenerationTesterMixin:
|
|||||||
seq_length,
|
seq_length,
|
||||||
config.hidden_size // config.num_attention_heads,
|
config.hidden_size // config.num_attention_heads,
|
||||||
)
|
)
|
||||||
# check shape key, value
|
|
||||||
self.assertListEqual(
|
if isinstance(past_key_values, Cache):
|
||||||
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
|
self.assertListEqual(
|
||||||
[expected_shape] * len(past_key_values),
|
[key_tensor.shape for key_tensor in past_key_values.key_cache],
|
||||||
)
|
[expected_shape] * len(past_key_values.key_cache),
|
||||||
self.assertListEqual(
|
)
|
||||||
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
|
self.assertListEqual(
|
||||||
[expected_shape] * len(past_key_values),
|
[value_tensor.shape for value_tensor in past_key_values.value_cache],
|
||||||
)
|
[expected_shape] * len(past_key_values.value_cache),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
|
||||||
|
else:
|
||||||
|
self.assertListEqual(
|
||||||
|
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
|
||||||
|
[True] * len(past_key_values),
|
||||||
|
)
|
||||||
|
# check shape key, value
|
||||||
|
self.assertListEqual(
|
||||||
|
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
|
||||||
|
[expected_shape] * len(past_key_values),
|
||||||
|
)
|
||||||
|
self.assertListEqual(
|
||||||
|
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
|
||||||
|
[expected_shape] * len(past_key_values),
|
||||||
|
)
|
||||||
|
|
||||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||||
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
|
# check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1.
|
||||||
|
|||||||
@@ -268,18 +268,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
|
|||||||
def test_sdpa_can_dispatch_on_flash(self):
|
def test_sdpa_can_dispatch_on_flash(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="")
|
|
||||||
def test_new_cache_format_0(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(reason="")
|
|
||||||
def test_new_cache_format_1(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(reason="")
|
|
||||||
def test_new_cache_format_2(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||||
def test_feed_forward_chunking(self):
|
def test_feed_forward_chunking(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import inspect
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from parameterized import parameterized
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@@ -395,11 +394,6 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skip(reason="Bamba has its own special cache type")
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_batching_equivalence(self):
|
def test_batching_equivalence(self):
|
||||||
# need to disable the tril input mask
|
# need to disable the tril input mask
|
||||||
orig = self.model_tester.use_input_mask
|
orig = self.model_tester.use_input_mask
|
||||||
|
|||||||
@@ -103,11 +103,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
|||||||
def test_dola_decoding_sample(self):
|
def test_dola_decoding_sample(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
@unittest.skip("Cohere2 has HybridCache and doesn't support old tuple format at all")
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("Cohere2 has HybridCache and doesn't support continue from past kv")
|
@unittest.skip("Cohere2 has HybridCache and doesn't support continue from past kv")
|
||||||
def test_generate_continue_from_past_key_values(self):
|
def test_generate_continue_from_past_key_values(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -117,11 +117,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
|||||||
def test_dola_decoding_sample(self):
|
def test_dola_decoding_sample(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
@unittest.skip("Gemma2 has HybridCache and doesn't support old tuple format at all")
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
|
@unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
|
||||||
def test_generate_continue_from_past_key_values(self):
|
def test_generate_continue_from_past_key_values(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from parameterized import parameterized
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer, JambaConfig, is_torch_available
|
from transformers import AutoTokenizer, JambaConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@@ -550,11 +549,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
"""
|
"""
|
||||||
self.skipTest(reason="Jamba flash attention does not support right padding")
|
self.skipTest(reason="Jamba flash attention does not support right padding")
|
||||||
|
|
||||||
@unittest.skip(reason="Jamba has its own special cache type")
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class JambaModelIntegrationTest(unittest.TestCase):
|
class JambaModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import gc
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from parameterized import parameterized
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer, JetMoeConfig, is_torch_available
|
from transformers import AutoTokenizer, JetMoeConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@@ -299,10 +298,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
test_disk_offload_bin = False
|
test_disk_offload_bin = False
|
||||||
test_disk_offload_safetensors = False
|
test_disk_offload_safetensors = False
|
||||||
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = JetMoeModelTester(self)
|
self.model_tester = JetMoeModelTester(self)
|
||||||
self.config_tester = ConfigTester(
|
self.config_tester = ConfigTester(
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from parameterized import parameterized
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer, ZambaConfig, is_torch_available
|
from transformers import AutoTokenizer, ZambaConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@@ -551,11 +550,6 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
"""
|
"""
|
||||||
self.skipTest(reason="Zamba flash attention does not support right padding")
|
self.skipTest(reason="Zamba flash attention does not support right padding")
|
||||||
|
|
||||||
@unittest.skip(reason="Zamba has its own special cache type")
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ZambaModelIntegrationTest(unittest.TestCase):
|
class ZambaModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user