Fix the initialization of the cache when we have multi gpu (#33303)
* init cache multi-gpu * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * switch to execution device map * naming more consistant * fix * mutually exclusive device * added an integration example * remove useless check * suggestion from joao + typing * fix couple of typo and add test * revert check --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -1030,6 +1030,9 @@ class StaticCache(Cache):
|
|||||||
The device on which the cache should be initialized. Should be the same as the layer.
|
The device on which the cache should be initialized. Should be the same as the layer.
|
||||||
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
||||||
The default `dtype` to use when initializing the layer.
|
The default `dtype` to use when initializing the layer.
|
||||||
|
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
|
||||||
|
Mapping between each layer and its device. This is required when you are manually initializing the cache and the model is split across different GPUs.
|
||||||
|
You can find out which layers are mapped to which device by checking the associated device map: `model.hf_device_map`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -1060,6 +1063,7 @@ class StaticCache(Cache):
|
|||||||
device: torch.device = None,
|
device: torch.device = None,
|
||||||
dtype: torch.dtype = torch.float32,
|
dtype: torch.dtype = torch.float32,
|
||||||
max_batch_size: Optional[int] = None,
|
max_batch_size: Optional[int] = None,
|
||||||
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if max_batch_size is not None:
|
if max_batch_size is not None:
|
||||||
@@ -1088,16 +1092,20 @@ class StaticCache(Cache):
|
|||||||
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
|
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
|
||||||
cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
||||||
for idx in range(config.num_hidden_layers):
|
for idx in range(config.num_hidden_layers):
|
||||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
if layer_device_map is not None:
|
||||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
layer_device = layer_device_map[idx]
|
||||||
|
else:
|
||||||
|
layer_device = device
|
||||||
|
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
||||||
|
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
||||||
# Notes:
|
# Notes:
|
||||||
# 1. `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
|
# 1. `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
|
||||||
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
|
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
|
||||||
# it is not needed anyway)
|
# it is not needed anyway)
|
||||||
# 2. `torch.export()` requires mutations to be registered as buffers.
|
# 2. `torch.export()` requires mutations to be registered as buffers.
|
||||||
if not is_torchdynamo_compiling():
|
if not is_torchdynamo_compiling():
|
||||||
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
|
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
|
||||||
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
|
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
|
||||||
new_layer_key_cache = getattr(self, f"key_cache_{idx}")
|
new_layer_key_cache = getattr(self, f"key_cache_{idx}")
|
||||||
new_layer_value_cache = getattr(self, f"value_cache_{idx}")
|
new_layer_value_cache = getattr(self, f"value_cache_{idx}")
|
||||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||||
@@ -1130,9 +1138,9 @@ class StaticCache(Cache):
|
|||||||
Return:
|
Return:
|
||||||
A tuple containing the updated key and value states.
|
A tuple containing the updated key and value states.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cache_position = cache_kwargs.get("cache_position")
|
cache_position = cache_kwargs.get("cache_position")
|
||||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
|
|
||||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
|
|
||||||
k_out = self.key_cache[layer_idx]
|
k_out = self.key_cache[layer_idx]
|
||||||
v_out = self.value_cache[layer_idx]
|
v_out = self.value_cache[layer_idx]
|
||||||
|
|
||||||
@@ -1201,6 +1209,9 @@ class SlidingWindowCache(StaticCache):
|
|||||||
The device on which the cache should be initialized. Should be the same as the layer.
|
The device on which the cache should be initialized. Should be the same as the layer.
|
||||||
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
||||||
The default `dtype` to use when initializing the layer.
|
The default `dtype` to use when initializing the layer.
|
||||||
|
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
|
||||||
|
Mapping between each layer and its device. This is required when you are manually initializing the cache and the model is split across different GPUs.
|
||||||
|
You can find out which layers are mapped to which device by checking the associated device map: `model.hf_device_map`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -1231,6 +1242,7 @@ class SlidingWindowCache(StaticCache):
|
|||||||
device: torch.device = None,
|
device: torch.device = None,
|
||||||
dtype: torch.dtype = torch.float32,
|
dtype: torch.dtype = torch.float32,
|
||||||
max_batch_size: Optional[int] = None,
|
max_batch_size: Optional[int] = None,
|
||||||
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
||||||
@@ -1247,6 +1259,7 @@ class SlidingWindowCache(StaticCache):
|
|||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
|
layer_device_map=layer_device_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
@@ -1280,7 +1293,6 @@ class SlidingWindowCache(StaticCache):
|
|||||||
v_out = v_out[:, :, indices]
|
v_out = v_out[:, :, indices]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cache_position.to(device=k_out.device)
|
|
||||||
k_out.index_copy_(2, cache_position, key_states)
|
k_out.index_copy_(2, cache_position, key_states)
|
||||||
v_out.index_copy_(2, cache_position, value_states)
|
v_out.index_copy_(2, cache_position, value_states)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
@@ -1495,6 +1507,9 @@ class HybridCache(Cache):
|
|||||||
The device on which the cache should be initialized. Should be the same as the layer.
|
The device on which the cache should be initialized. Should be the same as the layer.
|
||||||
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
|
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
|
||||||
The default `dtype` to use when initializing the layer.
|
The default `dtype` to use when initializing the layer.
|
||||||
|
layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
|
||||||
|
Mapping between each layer and its device. This is required when you are manually initializing the cache and the model is split across different GPUs.
|
||||||
|
You can find out which layers are mapped to which device by checking the associated device map: `model.hf_device_map`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -1525,6 +1540,7 @@ class HybridCache(Cache):
|
|||||||
device: Union[torch.device, str] = "cpu",
|
device: Union[torch.device, str] = "cpu",
|
||||||
dtype: torch.dtype = torch.float32,
|
dtype: torch.dtype = torch.float32,
|
||||||
max_batch_size: Optional[int] = None,
|
max_batch_size: Optional[int] = None,
|
||||||
|
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if max_batch_size is not None:
|
if max_batch_size is not None:
|
||||||
@@ -1562,11 +1578,15 @@ class HybridCache(Cache):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
)
|
)
|
||||||
for i in range(config.num_hidden_layers):
|
for i in range(config.num_hidden_layers):
|
||||||
|
if layer_device_map is not None:
|
||||||
|
layer_device = layer_device_map[i]
|
||||||
|
else:
|
||||||
|
layer_device = device
|
||||||
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
|
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
|
||||||
# breaks when updating the cache.
|
# breaks when updating the cache.
|
||||||
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
|
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
|
||||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
||||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
|
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
|
||||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||||
torch._dynamo.mark_static_address(new_layer_value_cache)
|
torch._dynamo.mark_static_address(new_layer_value_cache)
|
||||||
self.key_cache.append(new_layer_key_cache)
|
self.key_cache.append(new_layer_key_cache)
|
||||||
@@ -1617,8 +1637,6 @@ class HybridCache(Cache):
|
|||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
cache_position = cache_kwargs.get("cache_position")
|
cache_position = cache_kwargs.get("cache_position")
|
||||||
sliding_window = cache_kwargs.get("sliding_window")
|
sliding_window = cache_kwargs.get("sliding_window")
|
||||||
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
|
|
||||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
|
|
||||||
k_out = self.key_cache[layer_idx]
|
k_out = self.key_cache[layer_idx]
|
||||||
v_out = self.value_cache[layer_idx]
|
v_out = self.value_cache[layer_idx]
|
||||||
if sliding_window:
|
if sliding_window:
|
||||||
|
|||||||
@@ -1446,12 +1446,39 @@ class GenerationMixin:
|
|||||||
# models. May cause troubles with non-text modalities.
|
# models. May cause troubles with non-text modalities.
|
||||||
cache_dtype = self.get_output_embeddings().weight.dtype
|
cache_dtype = self.get_output_embeddings().weight.dtype
|
||||||
|
|
||||||
|
def get_layer_device_map(execution_device_map: Optional[dict] = None):
|
||||||
|
if execution_device_map is None or len(execution_device_map) <= 1:
|
||||||
|
return None
|
||||||
|
layer_device_map = {}
|
||||||
|
for layer in execution_device_map:
|
||||||
|
for idx in range(self.config.num_hidden_layers):
|
||||||
|
if f".{idx}." in f"{layer}.":
|
||||||
|
layer_device_map[idx] = execution_device_map[layer]
|
||||||
|
break
|
||||||
|
for idx in range(self.config.num_hidden_layers):
|
||||||
|
if idx not in layer_device_map:
|
||||||
|
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
|
||||||
|
return layer_device_map
|
||||||
|
|
||||||
|
execution_device_map = None
|
||||||
|
# Taken from dispatch_model from accelerate.
|
||||||
|
# This is needed here if we don't want to make changes in accelerate in order to save execution_device
|
||||||
|
# For offloaded case, we need to get the execution device, not just the device where it is offloaded
|
||||||
|
if hasattr(self, "hf_device_map"):
|
||||||
|
main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
|
||||||
|
execution_device_map = {
|
||||||
|
name: main_device if device in ["cpu", "disk"] else device
|
||||||
|
for name, device in self.hf_device_map.items()
|
||||||
|
}
|
||||||
|
layer_device_map = get_layer_device_map(execution_device_map)
|
||||||
|
|
||||||
cache_kwargs = {
|
cache_kwargs = {
|
||||||
"config": self.config if hasattr(self.config, "text_config") else self.config,
|
"config": self.config if hasattr(self.config, "text_config") else self.config,
|
||||||
"max_batch_size": batch_size,
|
"max_batch_size": batch_size,
|
||||||
"max_cache_len": max_cache_len,
|
"max_cache_len": max_cache_len,
|
||||||
"device": device,
|
"device": device,
|
||||||
"dtype": cache_dtype,
|
"dtype": cache_dtype,
|
||||||
|
"layer_device_map": layer_device_map,
|
||||||
}
|
}
|
||||||
self._cache = cache_cls(**cache_kwargs)
|
self._cache = cache_cls(**cache_kwargs)
|
||||||
if requires_cross_attention_cache:
|
if requires_cross_attention_cache:
|
||||||
|
|||||||
@@ -3444,6 +3444,91 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
self.assertTrue(test_bos_id == gen_output[0, 0])
|
self.assertTrue(test_bos_id == gen_output[0, 0])
|
||||||
self.assertTrue(generation_config.bos_token_id is None)
|
self.assertTrue(generation_config.bos_token_id is None)
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_generate_with_static_cache_multi_gpu(self):
|
||||||
|
"""
|
||||||
|
Tests if the static cache has been set correctly and if generate works correctly when we are using multi-gpus.
|
||||||
|
"""
|
||||||
|
# need to split manually as auto doesn't work well with unbalanced model
|
||||||
|
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||||
|
|
||||||
|
text = "Hello world"
|
||||||
|
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||||
|
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||||
|
|
||||||
|
generation_kwargs = {
|
||||||
|
"max_new_tokens": 20,
|
||||||
|
"cache_implementation": "static",
|
||||||
|
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||||
|
}
|
||||||
|
|
||||||
|
results = model.generate(input_ids, **generation_kwargs)
|
||||||
|
self.assertTrue(isinstance(results.past_key_values, StaticCache))
|
||||||
|
|
||||||
|
# check device of each layer
|
||||||
|
key_cache_0 = results.past_key_values.key_cache[0]
|
||||||
|
value_cache_0 = results.past_key_values.value_cache[0]
|
||||||
|
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
|
||||||
|
|
||||||
|
key_cache_1 = results.past_key_values.key_cache[1]
|
||||||
|
value_cache_1 = results.past_key_values.value_cache[1]
|
||||||
|
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_init_static_cache_multi_gpu(self):
|
||||||
|
"""
|
||||||
|
Tests if the static cache has been set correctly when we initialize it manually in a multi-gpu setup.
|
||||||
|
"""
|
||||||
|
# need to split manually as auto doesn't work well with unbalanced model
|
||||||
|
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||||
|
|
||||||
|
text = "Hello world"
|
||||||
|
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||||
|
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||||
|
|
||||||
|
generation_kwargs = {
|
||||||
|
"max_new_tokens": 20,
|
||||||
|
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: We need to raise a warning in case the cache is not set correctly
|
||||||
|
# with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
|
||||||
|
# past_key_values = StaticCache(
|
||||||
|
# config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
|
||||||
|
# )
|
||||||
|
# results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
|
||||||
|
|
||||||
|
# deduced from the device_map : layer 0 on device 0 and layer 1 on device 1
|
||||||
|
layer_device_map = {0: 0, 1: 1}
|
||||||
|
past_key_values = StaticCache(
|
||||||
|
config=model.config,
|
||||||
|
batch_size=1,
|
||||||
|
max_cache_len=30,
|
||||||
|
device=torch_device,
|
||||||
|
dtype=model.dtype,
|
||||||
|
layer_device_map=layer_device_map,
|
||||||
|
)
|
||||||
|
results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
|
||||||
|
|
||||||
|
# check device of each layer
|
||||||
|
key_cache_0 = results.past_key_values.key_cache[0]
|
||||||
|
value_cache_0 = results.past_key_values.value_cache[0]
|
||||||
|
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
|
||||||
|
|
||||||
|
key_cache_1 = results.past_key_values.key_cache[1]
|
||||||
|
value_cache_1 = results.past_key_values.value_cache[1]
|
||||||
|
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TokenHealingTestCase(unittest.TestCase):
|
class TokenHealingTestCase(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user