Offloaded cache: fix generate (#34921)

* fix cache impl

* require_torch_gpu

* fix mamba

* fix copies
This commit is contained in:
Raushan Turganbay
2024-11-28 15:05:56 +01:00
committed by GitHub
parent 57ca9e6d2f
commit 5e8c1d713d
6 changed files with 91 additions and 19 deletions

View File

@@ -1880,6 +1880,32 @@ class GenerationTesterMixin:
)
)
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
@require_torch_gpu
@pytest.mark.generate
def test_offloaded_cache_implementation(self, cache_implementation):
"""Tests we can generate by indicating `cache_implementation` for each possible cache class"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
self.skipTest(reason="This model does not support the new cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
"use_cache": True,
"cache_implementation": cache_implementation,
}
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
# Most cache classes have their own tests except for some that are tested here
# The ones here do not need special treatment when passing `cache_implementation`
# and are not bound to specific models only
new_results = model.generate(**generation_kwargs, **inputs_dict)
self.assertListEqual(legacy_results.tolist(), new_results.tolist())
@pytest.mark.generate
def test_generate_with_static_cache(self):
"""