[Backend support] Allow num_logits_to_keep as Tensor + add flag (#35757)
* support * Update modeling_utils.py * style * most models * Other models * fix-copies * tests + generation utils
This commit is contained in:
@@ -2029,10 +2029,10 @@ class GenerationTesterMixin:
|
||||
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_methods_with_num_logits_to_keep(self):
|
||||
def test_generate_methods_with_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
|
||||
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
config.use_cache = True
|
||||
@@ -2047,17 +2047,17 @@ class GenerationTesterMixin:
|
||||
"do_sample": False,
|
||||
}
|
||||
|
||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0)
|
||||
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
# Setting logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, logits_to_keep=0)
|
||||
# By default, logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||
def test_assisted_decoding_with_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
|
||||
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
|
||||
@@ -2081,9 +2081,9 @@ class GenerationTesterMixin:
|
||||
"output_scores": True,
|
||||
}
|
||||
|
||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0)
|
||||
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
# Setting logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, logits_to_keep=0)
|
||||
# By default, logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
|
||||
|
||||
self._check_similar_generate_outputs(with_all_logits, without_all_logits)
|
||||
|
||||
Reference in New Issue
Block a user