From 373e50e9703024eeb09f97945fdefde80acc2619 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 22 Jan 2025 09:49:17 +0100 Subject: [PATCH] Init cache on meta device (#35164) * init cache on meta device * offloaded static + enable tests * tests weren't running before :( * update * fix mamba * fix copies * update * address comments and fix tests * fix copies * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update * mamba fix --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/cache_utils.py | 129 +++++++++++++----- src/transformers/generation/utils.py | 35 +---- src/transformers/integrations/executorch.py | 1 + .../models/cohere2/modeling_cohere2.py | 1 - .../models/cohere2/modular_cohere2.py | 1 - .../models/gemma2/modeling_gemma2.py | 1 - .../models/gemma2/modular_gemma2.py | 1 - tests/models/llama/test_modeling_llama.py | 12 +- tests/test_modeling_common.py | 10 ++ tests/utils/test_cache_utils.py | 31 +---- 10 files changed, 111 insertions(+), 111 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index e616adbe67..b2be3f238d 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1069,12 +1069,15 @@ class StaticCache(Cache): The maximum sequence length with which the model will be used. device (`torch.device` or `str`): The device on which the cache should be initialized. Should be the same as the layer. + The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta` + device by default, and then moved to input device when updating. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus. You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. + Example: ```python @@ -1096,6 +1099,7 @@ class StaticCache(Cache): """ # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. + @deprecate_kwarg("layer_device_map", version="4.52.0") def __init__( self, config: PretrainedConfig, @@ -1122,6 +1126,7 @@ class StaticCache(Cache): ) self.dtype = dtype + self.device = torch.device(device) if device is not None else torch.device("meta") self.num_key_value_heads = ( config.num_attention_heads if getattr(config, "num_key_value_heads", None) is None @@ -1136,7 +1141,7 @@ class StaticCache(Cache): if layer_device_map is not None: layer_device = layer_device_map[idx] else: - layer_device = device + layer_device = self.device new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) # Notes: @@ -1181,6 +1186,9 @@ class StaticCache(Cache): """ cache_position = cache_kwargs.get("cache_position") + if self.key_cache[layer_idx].device.type == "meta": + self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device) + self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device) k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] @@ -1209,6 +1217,8 @@ class StaticCache(Cache): # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. # TODO: deprecate this function in favor of `cache_position` + if self.key_cache[layer_idx].device.type == "meta": + return 0 return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() def get_max_cache_shape(self) -> Optional[int]: @@ -1217,9 +1227,10 @@ class StaticCache(Cache): def reset(self): """Resets the cache values while preserving the objects""" for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() + if self.key_cache[layer_idx].device.type != "meta": + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() @property def batch_size(self): @@ -1257,6 +1268,8 @@ class SlidingWindowCache(StaticCache): The maximum sequence length with which the model will be used. device (`torch.device` or `str`): The device on which the cache should be initialized. Should be the same as the layer. + The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta` + device by default, and then moved to input device when updating. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): @@ -1321,8 +1334,15 @@ class SlidingWindowCache(StaticCache): cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor]: cache_position = cache_kwargs.get("cache_position") + + if self.key_cache[layer_idx].device.type == "meta": + self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device) + self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device) + k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) if cache_position.shape[0] > self.max_cache_len: @@ -1365,9 +1385,10 @@ class SlidingWindowCache(StaticCache): def reset(self): for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() + if self.key_cache[layer_idx].device.type != "meta": + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() class EncoderDecoderCache(Cache): @@ -1561,8 +1582,10 @@ class HybridCache(Cache): smaller batch size is used. max_cache_len (`int`): The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*, defaults to `"cpu"`): + device (`torch.device` or `str`, *optional*): The device on which the cache should be initialized. Should be the same as the layer. + The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta` + device by default, and then moved to input device when updating. dtype (torch.dtype, *optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): @@ -1590,12 +1613,13 @@ class HybridCache(Cache): """ # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. + @deprecate_kwarg("layer_device_map", version="4.52.0") def __init__( self, config: PretrainedConfig, batch_size: int = None, max_cache_len: int = None, - device: Union[torch.device, str] = "cpu", + device: Union[torch.device, str] = None, dtype: torch.dtype = torch.float32, max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, @@ -1623,9 +1647,11 @@ class HybridCache(Cache): self.num_key_value_heads = ( config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) + + self.device = torch.device(device) if device is not None else torch.device("meta") layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC self.is_sliding = torch.tensor( - [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device + [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool ) self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] @@ -1640,7 +1666,7 @@ class HybridCache(Cache): if layer_device_map is not None: layer_device = layer_device_map[i] else: - layer_device = device + layer_device = self.device # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph # breaks when updating the cache. cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape @@ -1696,8 +1722,16 @@ class HybridCache(Cache): ) -> Tuple[torch.Tensor]: cache_position = cache_kwargs.get("cache_position") sliding_window = cache_kwargs.get("sliding_window") + + if self.key_cache[layer_idx].device.type == "meta": + self.key_cache[layer_idx] = torch.zeros_like(self.key_cache[layer_idx], device=key_states.device) + self.value_cache[layer_idx] = torch.zeros_like(self.value_cache[layer_idx], device=value_states.device) + k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + if sliding_window: update_fn = self._sliding_update else: @@ -1725,14 +1759,18 @@ class HybridCache(Cache): "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " "Using the `layer_idx` argument is not supported." ) + + if self.key_cache[layer_idx].device.type == "meta": + return 0 return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() def reset(self): """Resets the cache values while preserving the objects""" for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() + if self.key_cache[layer_idx].device.type != "meta": + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() @property def batch_size(self): @@ -1757,10 +1795,14 @@ class MambaCache: The default `dtype` to use when initializing the layer. device (`torch.device` or `str`, *optional*): The device on which the cache should be initialized. Should be the same as the layer. + The recommended way however is not not indicate any `device`, in that case cache will be initialized on `meta` + device by default, and then moved to input device when updating. Attributes: dtype: (`torch.dtype`): The default `dtype` used to initializing the cache. + device (`torch.device`): + The default device on which the cache was initialized. intermediate_size: (`int`): Model's intermediate_size taken from config. ssm_state_size: (`int`): @@ -1809,30 +1851,40 @@ class MambaCache: self.intermediate_size = config.intermediate_size self.ssm_state_size = config.state_size self.conv_kernel_size = config.conv_kernel + self.device = torch.device(device) if device is not None else torch.device("meta") - self.conv_states: torch.Tensor = torch.zeros( - config.num_hidden_layers, - self.max_batch_size, - self.intermediate_size, - self.conv_kernel_size, - device=device, - dtype=dtype, - ) - self.ssm_states: torch.Tensor = torch.zeros( - config.num_hidden_layers, - self.max_batch_size, - self.intermediate_size, - self.ssm_state_size, - device=device, - dtype=dtype, - ) + self.conv_states: List[torch.Tensor] = [] + self.ssm_states: List[torch.Tensor] = [] + for _ in range(config.num_hidden_layers): + conv_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=self.device, + dtype=dtype, + ) + ssm_state: torch.Tensor = torch.zeros( + self.max_batch_size, + self.intermediate_size, + self.ssm_state_size, + device=self.device, + dtype=dtype, + ) - torch._dynamo.mark_static_address(self.conv_states) - torch._dynamo.mark_static_address(self.ssm_states) + torch._dynamo.mark_static_address(conv_state) + torch._dynamo.mark_static_address(ssm_state) + self.conv_states.append(conv_state) + self.ssm_states.append(ssm_state) def update_conv_state( self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor ) -> torch.Tensor: + if self.conv_states[layer_idx].device.type == "meta": + self.conv_states[layer_idx] = torch.zeros_like( + self.conv_states[layer_idx], + device=new_conv_state.device, + ) + conv_state = self.conv_states[layer_idx] cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) @@ -1843,12 +1895,15 @@ class MambaCache: return self.conv_states[layer_idx] def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) return self.ssm_states[layer_idx] def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() + for layer_idx in range(len(self.conv_states)): + if self.conv_states[layer_idx].device.type != "meta": + # In-place ops prevent breaking the static address + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() @property def batch_size(self): @@ -1920,6 +1975,7 @@ class OffloadedStaticCache(StaticCache): ``` """ + @deprecate_kwarg("layer_device_map", version="4.52.0") def __init__( self, config: PretrainedConfig, @@ -1930,9 +1986,10 @@ class OffloadedStaticCache(StaticCache): offload_device: Union[str, torch.device] = torch.device("cpu"), layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: + super(Cache, self).__init__() self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len - self.device = torch.device(device) if layer_device_map is None else layer_device_map[0] + self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) self.offload_device = torch.device(offload_device) self.dtype = dtype if dtype is not None else torch.float32 diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 655a388cb7..461d7e1215 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1633,45 +1633,12 @@ class GenerationMixin: # models. May cause trobles with non-text modalities. cache_dtype = self.get_output_embeddings().weight.dtype - def get_layer_device_map(execution_device_map: Optional[dict] = None): - num_hidden_layers = self.config.get_text_config().num_hidden_layers - if execution_device_map is None: - return None - elif len(execution_device_map) == 1 and "" in execution_device_map: - return {idx: execution_device_map[""] for idx in range(num_hidden_layers)} - layer_device_map = {} - for layer in execution_device_map: - for idx in range(num_hidden_layers): - if f".{idx}." in f"{layer}.": - layer_device_map[idx] = execution_device_map[layer] - break - for idx in range(num_hidden_layers): - if idx not in layer_device_map: - raise RuntimeError(f"layer {idx} has not been mapped to a device.") - return layer_device_map - - execution_device_map = None - # Taken from dispatch_model from accelerate. - # This is needed here if we don't want to make changes in accelerate in order to save execution_device - # For offloaded case, we need to get the execution device, not just the device where it is offloaded - if hasattr(self, "hf_device_map"): - if set(self.hf_device_map.values()) == {"cpu"} or set(self.hf_device_map.values()) == {"cpu", "disk"}: - main_device = "cpu" - else: - main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0] - execution_device_map = { - name: main_device if device in ["cpu", "disk"] else device - for name, device in self.hf_device_map.items() - } - layer_device_map = get_layer_device_map(execution_device_map) - cache_kwargs = { "config": self.config.get_text_config(), "max_batch_size": batch_size, "max_cache_len": max_cache_len, - "device": device, "dtype": cache_dtype, - "layer_device_map": layer_device_map, + "device": device if cache_implementation == "offloaded_static" else None, } self._cache = cache_cls(**cache_kwargs) if requires_cross_attention_cache: diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 258017f141..a0cbc8ba4e 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -73,6 +73,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): batch_size=self.model.generation_config.cache_config.batch_size, max_cache_len=self.model.generation_config.cache_config.max_cache_len, dtype=self.model.dtype, + device=self.model.generation_config.cache_config.device, ) self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures) if self.is_causal: diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 0b38c89d75..15469577fb 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -582,7 +582,6 @@ class Cohere2Model(Cohere2PreTrainedModel): self.config, max_batch_size=batch_size, max_cache_len=seq_len, - device=self.device, dtype=inputs_embeds.dtype, ) diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 78419e78c0..7020df2702 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -461,7 +461,6 @@ class Cohere2Model(Gemma2Model): self.config, max_batch_size=batch_size, max_cache_len=seq_len, - device=self.device, dtype=inputs_embeds.dtype, ) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index e64559b266..fb7e59051a 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -579,7 +579,6 @@ class Gemma2Model(Gemma2PreTrainedModel): self.config, max_batch_size=batch_size, max_cache_len=seq_len, - device=self.device, dtype=inputs_embeds.dtype, ) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 5f21fc6bff..53a947eb95 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -405,7 +405,6 @@ class Gemma2Model(GemmaModel): self.config, max_batch_size=batch_size, max_cache_len=seq_len, - device=self.device, dtype=inputs_embeds.dtype, ) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 664616306d..8d492ce673 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -728,22 +728,13 @@ class LlamaIntegrationTest(unittest.TestCase): dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) - # Static Cache + # Static Cache + compile (`generate()` internally compiles each decoding step when static cache is used) generated_ids = model.generate( **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" ) static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) - # Static Cache + compile - model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"` - model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" - ) - static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) - @slow @require_read_token def test_export_static_cache(self): @@ -795,6 +786,7 @@ class LlamaIntegrationTest(unittest.TestCase): cache_config={ "batch_size": batch_size, "max_cache_len": max_generation_length, + "device": device, }, ), ) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index cf259fabe3..d361378503 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4635,6 +4635,11 @@ class ModelTesterMixin: fa2_correctly_converted = True break + fa2_correctly_converted = ( + fa2_correctly_converted + if not model_class._supports_flex_attn + else fa2_model.config._attn_implementation == "flash_attention_2" + ) self.assertTrue(fa2_correctly_converted) _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask) @@ -4653,6 +4658,11 @@ class ModelTesterMixin: fa2_correctly_converted = True break + fa2_correctly_converted = ( + fa2_correctly_converted + if not model_class._supports_flex_attn + else model_from_pretrained.config._attn_implementation == "flash_attention_2" + ) self.assertFalse(fa2_correctly_converted) def _get_custom_4d_mask_test_data(self): diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 053d2cf639..d67b026638 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -198,6 +198,7 @@ class CacheTest(unittest.TestCase): cache_config={ "batch_size": batch_size, "max_cache_len": max_cache_len, + "device": device, }, ), ) @@ -310,11 +311,12 @@ class CacheIntegrationTest(unittest.TestCase): do_sample=False, max_new_tokens=20, num_return_sequences=2, + num_beams=2, ) decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) expected_text = [ - "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", - "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many", + "Hello I am doing a project for my school and I am trying to make a program that will allow me to input a", + "Hello I am doing a project for my school and I am trying to make a program that will allow me to use a", ] self.assertListEqual(decoded, expected_text) @@ -380,8 +382,6 @@ class CacheIntegrationTest(unittest.TestCase): [ ("eager", "static"), ("sdpa", "static"), - ("eager", "offloaded-static"), - ("sdpa", "offloaded-static"), ] ) def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation): @@ -427,8 +427,6 @@ class CacheIntegrationTest(unittest.TestCase): [ ("eager", "static"), ("sdpa", "static"), - ("eager", "offloaded-static"), - ("sdpa", "offloaded-static"), ] ) def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation): @@ -462,26 +460,6 @@ class CacheIntegrationTest(unittest.TestCase): with self.subTest(f"{attn_implementation}, static, eager"): self.assertListEqual(decoded, EXPECTED_GENERATION) - set_seed(0) - model._forward = model.forward - compiled_forward = torch.compile(model.forward) - - def compiled(func, input_ids, **kwargs): - return func(input_ids, **kwargs) - - def call(input_ids, **kwargs): - if input_ids.shape[-1] == 1: - return compiled(compiled_forward, input_ids, **kwargs) - - return model._forward(input_ids, **kwargs) - - model.forward = call - - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - with self.subTest(f"{attn_implementation}, static, compiled"): - self.assertListEqual(decoded, EXPECTED_GENERATION) - def test_dynamic_cache_extra_left_padding(self): """Tests that adding extra left-padding does not affect the generation with the dynamic cache""" EXPECTED_GENERATION = [ @@ -519,7 +497,6 @@ class CacheIntegrationTest(unittest.TestCase): @parameterized.expand( [ "static", - "offloaded-static", ] ) def test_static_cache_extra_left_padding(self, cache_implementation):