Reducing memory usage: removing useless logits computation in generate() (#31292)

* Add .float() in all generation methods logit outputs

* Switch float-casting of logits to training only for main models

* Add `num_logits_to_keep` in Llama and add it by default in generate

* Apply style

* Add num_logits_to_keep as arg in prepare_input_for_generation

* Add support for Mistral

* Revert models except llama and mistral

* Fix default None value in _supports_num_logits_to_keep()

* Fix dimension of dummy input

* Add exception for prophetnet in _supports_num_logits_to_keep()

* Update _supports_num_logits_to_keep() to use inspect.signature()

* Add deprecation cycle + remove modification with pretraining_tp

* Apply style

* Add most used models

* Apply style

* Make `num_logits_to_keep` an int in all cases to remove if-else clause

* Add compile check for the warning

* Fix torch versions

* style

* Add gemma2

* Update warning version

* Add comment about .float operations in generation utils

* Add tests in GenerationTesterMixin and ModelTesterMixin

* Fix batch size for assisted decoding in tests

* fix small issues in test

* refacor test

* fix slicing removing dim issue

* Add nemotron support (should fix check-copy issue in CIs)

* Trigger new CIs

* Trigger new CIs

* Bump version

* Bump version in TODO

* Trigger CIs

* remove blank space

* Trigger CIs
This commit is contained in:
Cyril Vallez
2024-08-23 12:08:34 +02:00
committed by GitHub
parent d806fa3e92
commit 22e6f14525
23 changed files with 428 additions and 41 deletions

View File

@@ -1828,6 +1828,62 @@ class GenerationTesterMixin:
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
def test_generate_methods_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
# other methods will work as well)
generation_kwargs = {
"max_new_tokens": 10,
"do_sample": False,
}
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(
input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0
)
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
def test_assisted_decoding_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support assisted generation")
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
assistant_model = model
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
# other methods will work as well)
generation_kwargs = {
"max_new_tokens": 10,
"do_sample": False,
"assistant_model": assistant_model,
}
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(
input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0
)
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences