[caches] Raise exception on offloaded static caches + multi device (#37974)
* skip tests on >1 gpu * add todo
This commit is contained in:
@@ -198,19 +198,24 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
cls.model.config.sliding_window = 256 # hack to enable the use of caches with sliding windows
|
||||
|
||||
def _skip_on_uninstalled_cache_dependencies(self, cache_implementation):
|
||||
"""Function to skip tests on missing cache dependencies, given a cache implementation"""
|
||||
def _skip_on_failed_cache_prerequisites(self, cache_implementation):
|
||||
"""Function to skip tests on failed cache prerequisites, given a cache implementation"""
|
||||
# Installed dependencies
|
||||
if cache_implementation == "quantized" and not is_optimum_quanto_available():
|
||||
self.skipTest("Quanto is not available")
|
||||
# Devices
|
||||
if "offloaded" in cache_implementation:
|
||||
has_accelerator = torch_device is not None and torch_device != "cpu"
|
||||
if not has_accelerator:
|
||||
self.skipTest("Offloaded caches require an accelerator")
|
||||
if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
|
||||
if torch.cuda.device_count() != 1:
|
||||
self.skipTest("Offloaded static caches require exactly 1 GPU")
|
||||
|
||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
|
||||
def test_cache_batched(self, cache_implementation):
|
||||
"""Sanity check: caches' `.update` function expects batched inputs"""
|
||||
self._skip_on_uninstalled_cache_dependencies(cache_implementation)
|
||||
self._skip_on_failed_cache_prerequisites(cache_implementation)
|
||||
|
||||
EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
|
||||
|
||||
@@ -239,7 +244,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices
|
||||
(an output sequence contains multiple beam indices).
|
||||
"""
|
||||
self._skip_on_uninstalled_cache_dependencies(cache_implementation)
|
||||
self._skip_on_failed_cache_prerequisites(cache_implementation)
|
||||
if cache_implementation == "offloaded_hybrid_chunked":
|
||||
# TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
|
||||
# output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly
|
||||
@@ -273,7 +278,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
|
||||
def test_cache_extra_left_padding(self, cache_implementation):
|
||||
"""Tests that adding extra left-padding does not affect the generation with the cache"""
|
||||
self._skip_on_uninstalled_cache_dependencies(cache_implementation)
|
||||
self._skip_on_failed_cache_prerequisites(cache_implementation)
|
||||
|
||||
EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user