Reducing memory usage: removing useless logits computation in generate() (#31292)
* Add .float() in all generation methods logit outputs * Switch float-casting of logits to training only for main models * Add `num_logits_to_keep` in Llama and add it by default in generate * Apply style * Add num_logits_to_keep as arg in prepare_input_for_generation * Add support for Mistral * Revert models except llama and mistral * Fix default None value in _supports_num_logits_to_keep() * Fix dimension of dummy input * Add exception for prophetnet in _supports_num_logits_to_keep() * Update _supports_num_logits_to_keep() to use inspect.signature() * Add deprecation cycle + remove modification with pretraining_tp * Apply style * Add most used models * Apply style * Make `num_logits_to_keep` an int in all cases to remove if-else clause * Add compile check for the warning * Fix torch versions * style * Add gemma2 * Update warning version * Add comment about .float operations in generation utils * Add tests in GenerationTesterMixin and ModelTesterMixin * Fix batch size for assisted decoding in tests * fix small issues in test * refacor test * fix slicing removing dim issue * Add nemotron support (should fix check-copy issue in CIs) * Trigger new CIs * Trigger new CIs * Bump version * Bump version in TODO * Trigger CIs * remove blank space * Trigger CIs
This commit is contained in:
@@ -4824,6 +4824,27 @@ class ModelTesterMixin:
|
||||
self.assertTrue(record_time < 0.15 * graph_warmup_time)
|
||||
self.assertTrue(opt_time < record_time)
|
||||
|
||||
def test_forward_with_num_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
|
||||
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
batch_size, sequence_length = inputs["input_ids"].shape
|
||||
vocab_size = config.vocab_size
|
||||
model = model_class(config).to(device=torch_device).eval()
|
||||
|
||||
# num_logits_to_keep=0 is a special case meaning "keep all logits"
|
||||
all_logits = model(**inputs, num_logits_to_keep=0).logits
|
||||
last_token_logits = model(**inputs, num_logits_to_keep=1).logits
|
||||
|
||||
# Assert all shapes are correct
|
||||
self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size))
|
||||
self.assertEqual(tuple(last_token_logits.shape), (batch_size, 1, vocab_size))
|
||||
|
||||
# Assert the last tokens are actually the same
|
||||
self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits))
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user