Reducing memory usage: removing useless logits computation in generate() (#31292)

* Add .float() in all generation methods logit outputs

* Switch float-casting of logits to training only for main models

* Add `num_logits_to_keep` in Llama and add it by default in generate

* Apply style

* Add num_logits_to_keep as arg in prepare_input_for_generation

* Add support for Mistral

* Revert models except llama and mistral

* Fix default None value in _supports_num_logits_to_keep()

* Fix dimension of dummy input

* Add exception for prophetnet in _supports_num_logits_to_keep()

* Update _supports_num_logits_to_keep() to use inspect.signature()

* Add deprecation cycle + remove modification with pretraining_tp

* Apply style

* Add most used models

* Apply style

* Make `num_logits_to_keep` an int in all cases to remove if-else clause

* Add compile check for the warning

* Fix torch versions

* style

* Add gemma2

* Update warning version

* Add comment about .float operations in generation utils

* Add tests in GenerationTesterMixin and ModelTesterMixin

* Fix batch size for assisted decoding in tests

* fix small issues in test

* refacor test

* fix slicing removing dim issue

* Add nemotron support (should fix check-copy issue in CIs)

* Trigger new CIs

* Trigger new CIs

* Bump version

* Bump version in TODO

* Trigger CIs

* remove blank space

* Trigger CIs
This commit is contained in:
Cyril Vallez
2024-08-23 12:08:34 +02:00
committed by GitHub
parent d806fa3e92
commit 22e6f14525
23 changed files with 428 additions and 41 deletions

View File

@@ -4824,6 +4824,27 @@ class ModelTesterMixin:
self.assertTrue(record_time < 0.15 * graph_warmup_time)
self.assertTrue(opt_time < record_time)
def test_forward_with_num_logits_to_keep(self):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
batch_size, sequence_length = inputs["input_ids"].shape
vocab_size = config.vocab_size
model = model_class(config).to(device=torch_device).eval()
# num_logits_to_keep=0 is a special case meaning "keep all logits"
all_logits = model(**inputs, num_logits_to_keep=0).logits
last_token_logits = model(**inputs, num_logits_to_keep=1).logits
# Assert all shapes are correct
self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size))
self.assertEqual(tuple(last_token_logits.shape), (batch_size, 1, vocab_size))
# Assert the last tokens are actually the same
self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits))
global_rng = random.Random()