Offloaded cache: fix generate (#34921)
* fix cache impl * require_torch_gpu * fix mamba * fix copies
This commit is contained in:
committed by
GitHub
parent
57ca9e6d2f
commit
5e8c1d713d
@@ -1880,6 +1880,32 @@ class GenerationTesterMixin:
|
||||
)
|
||||
)
|
||||
|
||||
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
|
||||
@require_torch_gpu
|
||||
@pytest.mark.generate
|
||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||
"""Tests we can generate by indicating `cache_implementation` for each possible cache class"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_cache_class:
|
||||
self.skipTest(reason="This model does not support the new cache format")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": 5,
|
||||
"use_cache": True,
|
||||
"cache_implementation": cache_implementation,
|
||||
}
|
||||
|
||||
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
# Most cache classes have their own tests except for some that are tested here
|
||||
# The ones here do not need special treatment when passing `cache_implementation`
|
||||
# and are not bound to specific models only
|
||||
new_results = model.generate(**generation_kwargs, **inputs_dict)
|
||||
self.assertListEqual(legacy_results.tolist(), new_results.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_with_static_cache(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user