Fix the initialization of the cache when we have multi gpu (#33303)
* init cache multi-gpu
* Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* switch to execution device map
* naming more consistent
* fix
* mutually exclusive device
* added an integration example
* remove useless check
* suggestion from joao + typing
* fix a couple of typos and add test
* revert check
--------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -3444,6 +3444,91 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(test_bos_id == gen_output[0, 0])
|
||||
self.assertTrue(generation_config.bos_token_id is None)
|
||||
|
||||
@pytest.mark.generate
@require_torch_multi_gpu
def test_generate_with_static_cache_multi_gpu(self):
    """
    Tests if the static cache has been set correctly and if generate works correctly when we are using multi-gpus.

    The model is loaded with an explicit `device_map` that places decoder layer 0 on device 0 and decoder
    layer 1 on device 1. After calling `generate` with `cache_implementation="static"`, the returned
    `past_key_values` must be a `StaticCache`, and the key/value tensors of each layer must be on the device
    that the `device_map` assigns to that layer.
    """
    checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"

    # need to split manually as auto doesn't work well with unbalanced model
    device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=device_map)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    text = "Hello world"
    tokenized_inputs = tokenizer([text], return_tensors="pt")
    input_ids = tokenized_inputs.input_ids.to(torch_device)

    generation_kwargs = {
        "max_new_tokens": 20,
        "cache_implementation": "static",
        "return_dict_in_generate": True,  # Required to return `past_key_values`
    }

    results = model.generate(input_ids, **generation_kwargs)
    # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...))
    self.assertIsInstance(results.past_key_values, StaticCache)

    # check device of each layer: the expected device is read from the device_map entry of the
    # matching decoder layer, so the expectation cannot drift from the placement used above
    cache = results.past_key_values
    for layer_idx in (0, 1):
        expected_device = torch.device(device_map[f"model.layers.{layer_idx}"])
        self.assertEqual(
            cache.key_cache[layer_idx].device, expected_device, msg=f"key cache of layer {layer_idx}"
        )
        self.assertEqual(
            cache.value_cache[layer_idx].device, expected_device, msg=f"value cache of layer {layer_idx}"
        )
|
||||
|
||||
@pytest.mark.generate
@require_torch_multi_gpu
def test_init_static_cache_multi_gpu(self):
    """
    Tests if the static cache has been set correctly when we initialize it manually in a multi-gpu setup.

    A `StaticCache` is built by hand with a `layer_device_map` that mirrors the model's `device_map`
    (layer 0 on device 0, layer 1 on device 1) and passed to `generate` through `past_key_values`. The
    key/value tensors of each layer in the returned cache must be on the device given by `layer_device_map`.
    """
    checkpoint = "hf-internal-testing/tiny-random-MistralForCausalLM"

    # need to split manually as auto doesn't work well with unbalanced model
    device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 0}
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=device_map)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    text = "Hello world"
    tokenized_inputs = tokenizer([text], return_tensors="pt")
    input_ids = tokenized_inputs.input_ids.to(torch_device)

    generation_kwargs = {
        "max_new_tokens": 20,
        "return_dict_in_generate": True,  # Required to return `past_key_values`
    }

    # TODO: We need to raise a warning in case the cache is not set correctly
    # with self.assertRaisesRegex(ValueError, "If you are manually initializing the cache"):
    #     past_key_values = StaticCache(
    #         config=model.config, batch_size=1, max_cache_len=30, device=torch_device, dtype=model.dtype
    #     )
    #     results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)

    # deduced from the device_map : layer 0 on device 0 and layer 1 on device 1
    layer_device_map = {0: 0, 1: 1}
    past_key_values = StaticCache(
        config=model.config,
        batch_size=1,
        max_cache_len=30,
        device=torch_device,
        dtype=model.dtype,
        layer_device_map=layer_device_map,
    )
    results = model.generate(input_ids, past_key_values=past_key_values, **generation_kwargs)

    # check device of each layer: iterate the same mapping that was handed to StaticCache so the
    # expectation and the configuration share a single source of truth
    cache = results.past_key_values
    for layer_idx, device_idx in layer_device_map.items():
        expected_device = torch.device(device_idx)
        self.assertEqual(
            cache.key_cache[layer_idx].device, expected_device, msg=f"key cache of layer {layer_idx}"
        )
        self.assertEqual(
            cache.value_cache[layer_idx].device, expected_device, msg=f"value cache of layer {layer_idx}"
        )
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user