From f2b59c6173191089dadda197554435ce96ae6c84 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 8 May 2025 14:37:36 +0100 Subject: [PATCH] [caches] Raise exception on offloaded static caches + multi device (#37974) * skip tests on >1 gpu * add todo --- src/transformers/cache_utils.py | 14 ++++++++++++++ tests/utils/test_cache_utils.py | 15 ++++++++++----- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 85a09f03de..005612e82f 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -2028,6 +2028,13 @@ class OffloadedHybridCache(HybridChunkedCache): layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ): super().__init__(config, max_batch_size, max_cache_len, device, dtype, layer_device_map) + + # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps + # track of the original device of each layer + unique_devices = set(layer_device_map.values()) + if len(unique_devices) > 1: + raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}") + self.offload_device = torch.device(offload_device) # Create new CUDA stream for parallel prefetching. self._prefetch_stream = torch.cuda.Stream() if torch._C._get_accelerator().type == "cuda" else None @@ -2280,6 +2287,13 @@ class OffloadedStaticCache(StaticCache): layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: super(Cache, self).__init__() + + # TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps + # track of the original device of each layer + unique_devices = set(layer_device_map.values()) + if len(unique_devices) > 1: + raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}") + self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 691c5aa535..48cecb52dc 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -198,19 +198,24 @@ class CacheIntegrationTest(unittest.TestCase): ) cls.model.config.sliding_window = 256 # hack to enable the use of caches with sliding windows - def _skip_on_uninstalled_cache_dependencies(self, cache_implementation): - """Function to skip tests on missing cache dependencies, given a cache implementation""" + def _skip_on_failed_cache_prerequisites(self, cache_implementation): + """Function to skip tests on failed cache prerequisites, given a cache implementation""" + # Installed dependencies if cache_implementation == "quantized" and not is_optimum_quanto_available(): self.skipTest("Quanto is not available") + # Devices if "offloaded" in cache_implementation: has_accelerator = torch_device is not None and torch_device != "cpu" if not has_accelerator: self.skipTest("Offloaded caches require an accelerator") + if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]: + if torch.cuda.device_count() != 1: + self.skipTest("Offloaded static caches require exactly 1 GPU") @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) def test_cache_batched(self, cache_implementation): """Sanity check: caches' `.update` function expects batched inputs""" - self._skip_on_uninstalled_cache_dependencies(cache_implementation) + self._skip_on_failed_cache_prerequisites(cache_implementation) EXPECTED_GENERATION = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"] @@ -239,7 +244,7 @@ class CacheIntegrationTest(unittest.TestCase): Sanity check: caches' `reorder_cache` is operational. We can confirm this by looking at the beam indices (an output sequence contains multiple beam indices). """ - self._skip_on_uninstalled_cache_dependencies(cache_implementation) + self._skip_on_failed_cache_prerequisites(cache_implementation) if cache_implementation == "offloaded_hybrid_chunked": # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the # output sequence (and the corresponding beam scores, if we add `output_scores=True`) are significantly @@ -273,7 +278,7 @@ class CacheIntegrationTest(unittest.TestCase): @parameterized.expand(TEST_CACHE_IMPLEMENTATIONS) def test_cache_extra_left_padding(self, cache_implementation): """Tests that adding extra left-padding does not affect the generation with the cache""" - self._skip_on_uninstalled_cache_dependencies(cache_implementation) + self._skip_on_failed_cache_prerequisites(cache_implementation) EXPECTED_GENERATION = ["The cat's whiskers are also a sign of anxiety."]