Reducing memory usage: removing useless logits computation in generate() (#31292)
* Add .float() in all generation methods logit outputs * Switch float-casting of logits to training only for main models * Add `num_logits_to_keep` in Llama and add it by default in generate * Apply style * Add num_logits_to_keep as arg in prepare_input_for_generation * Add support for Mistral * Revert models except llama and mistral * Fix default None value in _supports_num_logits_to_keep() * Fix dimension of dummy input * Add exception for prophetnet in _supports_num_logits_to_keep() * Update _supports_num_logits_to_keep() to use inspect.signature() * Add deprecation cycle + remove modification with pretraining_tp * Apply style * Add most used models * Apply style * Make `num_logits_to_keep` an int in all cases to remove if-else clause * Add compile check for the warning * Fix torch versions * style * Add gemma2 * Update warning version * Add comment about .float operations in generation utils * Add tests in GenerationTesterMixin and ModelTesterMixin * Fix batch size for assisted decoding in tests * fix small issues in test * refacor test * fix slicing removing dim issue * Add nemotron support (should fix check-copy issue in CIs) * Trigger new CIs * Trigger new CIs * Bump version * Bump version in TODO * Trigger CIs * remove blank space * Trigger CIs
This commit is contained in:
@@ -1828,6 +1828,62 @@ class GenerationTesterMixin:
|
||||
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
|
||||
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
|
||||
|
||||
def test_generate_methods_with_num_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
|
||||
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
|
||||
# other methods will work as well)
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": 10,
|
||||
"do_sample": False,
|
||||
}
|
||||
|
||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0
|
||||
)
|
||||
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
assistant_model = model
|
||||
# All generation methods (except assisted decoding) rely on always extracting the last token logits of the
|
||||
# full logits matrix, so testing out only greedy search and assisted decoding is enough (if it works,
|
||||
# other methods will work as well)
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": 10,
|
||||
"do_sample": False,
|
||||
"assistant_model": assistant_model,
|
||||
}
|
||||
|
||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(
|
||||
input_ids, attention_mask=attention_mask, **generation_kwargs, num_logits_to_keep=0
|
||||
)
|
||||
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
without_all_logits = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
|
||||
Reference in New Issue
Block a user