[caches] Raise exception on offloaded static caches + multi device (#37974)

* skip tests on >1 gpu

* add todo
This commit is contained in:
Joao Gante
2025-05-08 14:37:36 +01:00
committed by GitHub
parent 4279057d70
commit f2b59c6173
2 changed files with 24 additions and 5 deletions

View File

@@ -198,19 +198,24 @@ class CacheIntegrationTest(unittest.TestCase):
)
cls.model.config.sliding_window = 256 # hack to enable the use of caches with sliding windows
def _skip_on_uninstalled_cache_dependencies(self, cache_implementation):
"""Function to skip tests on missing cache dependencies, given a cache implementation"""
def _skip_on_failed_cache_prerequisites(self, cache_implementation):
"""Function to skip tests on failed cache prerequisites, given a cache implementation"""
# Installed dependencies
if cache_implementation == "quantized" and not is_optimum_quanto_available():
self.skipTest("Quanto is not available")
# Devices
if "offloaded" in cache_implementation:
has_accelerator = torch_device is not None and torch_device != "cpu"
if not has_accelerator:
self.skipTest("Offloaded caches require an accelerator")
if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
if torch.cuda.device_count() != 1:
self.skipTest("Offloaded static caches require exactly 1 GPU")
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_batched(self, cache_implementation):
"""Sanity check: caches' `.update` function expects batched inputs"""
self._skip_on_uninstalled_cache_dependencies(cache_implementation)
self._skip_on_failed_cache_prerequisites(cache_implementation)
EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
@@ -239,7 +244,7 @@ class CacheIntegrationTest(unittest.TestCase):
Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
(an output sequence contains multiple beam indices).
"""
self._skip_on_uninstalled_cache_dependencies(cache_implementation)
self._skip_on_failed_cache_prerequisites(cache_implementation)
if cache_implementation == "offloaded_hybrid_chunked":
# TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
# output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
@@ -273,7 +278,7 @@ class CacheIntegrationTest(unittest.TestCase):
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_extra_left_padding(self, cache_implementation):
"""Tests that adding extra left-padding does not affect the generation with the cache"""
self._skip_on_uninstalled_cache_dependencies(cache_implementation)
self._skip_on_failed_cache_prerequisites(cache_implementation)
EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]