Fix the initialization of the cache when we have multi gpu (#33303)

* init cache multi-gpu

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* switch to execution device map

* naming more consistant

* fix

* mutually exclusive device

* added an integration example

* remove useless check

* suggestion from joao + typing

* fix couple of typo and add test

* revert check

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Marc Sun
2024-09-13 15:06:08 +02:00
committed by GitHub
parent dfd31158ee
commit 6cc4dfe3f1
3 changed files with 141 additions and 11 deletions

View File

@@ -3444,6 +3444,91 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)
@pytest.mark.generate
@require_torch_multi_gpu
def test_generate_with_static_cache_multi_gpu(self):
"""
Tests if the static cache has been set correctly and if generate works correctly when we are using multi-gpus.
"""
# need to split manually as auto doesn't work well with unbalanced model
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
generation_kwargs = {
"max_new_tokens": 20,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}
results = model.generate(input_ids, **generation_kwargs)
self.assertTrue(isinstance(results.past_key_values, StaticCache))
# check device of each layer
key_cache_0 = results.past_key_values.key_cache[0]
value_cache_0 = results.past_key_values.value_cache[0]
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
key_cache_1 = results.past_key_values.key_cache[1]
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
@pytest.mark.generate
@require_torch_multi_gpu
def test_init_static_cache_multi_gpu(self):
"""
Tests if the static cache has been set correctly when we initialize it manually in a multi-gpu setup.
"""
# need to split manually as auto doesn't work well with unbalanced model
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map=device_map
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
text = "Hello world"
tokenized_inputs = tokenizer([text], return_tensors="pt")
input_ids = tokenized_inputs.input_ids.to(torch_device)
generation_kwargs = {
"max_new_tokens": 20,
"return_dict_in_generate": True, # Required to return `past_key_values`
}
# TODO: We need to raise a warning in case the cache is not set correctly
# with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
# past_key_values = StaticCache(
# config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
# )
# results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
# deduced from the device_map : layer 0 on device 0 and layer 1 on device 1
layer_device_map = {0: 0, 1: 1}
past_key_values = StaticCache(
config=model.config,
batch_size=1,
max_cache_len=30,
device=torch_device,
dtype=model.dtype,
layer_device_map=layer_device_map,
)
results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)
# check device of each layer
key_cache_0 = results.past_key_values.key_cache[0]
value_cache_0 = results.past_key_values.value_cache[0]
self.assertTrue(key_cache_0.device == value_cache_0.device == torch.device(0))
key_cache_1 = results.past_key_values.key_cache[1]
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))
@require_torch
class TokenHealingTestCase(unittest.TestCase):