[Backend support] Allow num_logits_to_keep as Tensor + add flag (#35757)
* support * Update modeling_utils.py * style * most models * Other models * fix-copies * tests + generation utils
This commit is contained in:
@@ -2029,10 +2029,10 @@ class GenerationTesterMixin:
|
||||
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_methods_with_num_logits_to_keep(self):
|
||||
def test_generate_methods_with_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
|
||||
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
config.use_cache = True
|
||||
@@ -2047,17 +2047,17 @@ class GenerationTesterMixin:
|
||||
"do_sample": False,
|
||||
}
|
||||
|
||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0)
|
||||
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
# Setting logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, logits_to_keep=0)
|
||||
# By default, logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
|
||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||
def test_assisted_decoding_with_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
|
||||
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="Stateful models don't support assisted generation")
|
||||
|
||||
@@ -2081,9 +2081,9 @@ class GenerationTesterMixin:
|
||||
"output_scores": True,
|
||||
}
|
||||
|
||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0)
|
||||
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
# Setting logits_to_keep at 0 keeps all logits (old behavior)
|
||||
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, logits_to_keep=0)
|
||||
# By default, logits_to_keep is automatically set to 1 if not provided (new behavior)
|
||||
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
|
||||
|
||||
self._check_similar_generate_outputs(with_all_logits, without_all_logits)
|
||||
|
||||
@@ -531,7 +531,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
|
||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||
if self.cuda_compute_capability_major_version == 8:
|
||||
with torch.no_grad():
|
||||
logits = self.model(input_ids=input_ids, num_logits_to_keep=40).logits
|
||||
logits = self.model(input_ids=input_ids, logits_to_keep=40).logits
|
||||
|
||||
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
|
||||
[
|
||||
|
||||
@@ -4759,21 +4759,21 @@ class ModelTesterMixin:
|
||||
for name, param in model._orig_mod.named_parameters():
|
||||
torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_forward_with_num_logits_to_keep(self):
|
||||
def test_forward_with_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
|
||||
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
|
||||
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
batch_size, sequence_length = inputs["input_ids"].shape
|
||||
vocab_size = config.get_text_config().vocab_size
|
||||
model = model_class(config).to(device=torch_device).eval()
|
||||
# some models have labels but `num_logits_to_keep` should not be used in train mode
|
||||
# some models have labels but `logits_to_keep` should not be used in train mode
|
||||
_ = inputs.pop("labels", None)
|
||||
|
||||
# num_logits_to_keep=0 is a special case meaning "keep all logits"
|
||||
all_logits = model(**inputs, num_logits_to_keep=0).logits
|
||||
last_token_logits = model(**inputs, num_logits_to_keep=1).logits
|
||||
# logits_to_keep=0 is a special case meaning "keep all logits"
|
||||
all_logits = model(**inputs, logits_to_keep=0).logits
|
||||
last_token_logits = model(**inputs, logits_to_keep=1).logits
|
||||
|
||||
# Assert all shapes are correct
|
||||
self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size))
|
||||
|
||||
@@ -17,10 +17,15 @@ import warnings
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import __version__
|
||||
from transformers import __version__, is_torch_available
|
||||
from transformers.testing_utils import require_torch_gpu
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
INFINITE_VERSION = "9999.0.0"
|
||||
|
||||
|
||||
@@ -168,3 +173,23 @@ class DeprecationDecoratorTester(unittest.TestCase):
|
||||
with self.assertWarns(FutureWarning):
|
||||
result = dummy_function(deprecated_name="old_value", new_name="new_value")
|
||||
self.assertEqual(result, "new_value")
|
||||
|
||||
@require_torch_gpu
|
||||
def test_compile_safe(self):
|
||||
@deprecate_kwarg("deprecated_factor", new_name="new_factor", version=INFINITE_VERSION)
|
||||
def dummy_function(new_factor=None, **kwargs):
|
||||
return new_factor * torch.ones(1, device="cuda")
|
||||
|
||||
compiled_function = torch.compile(dummy_function, fullgraph=True)
|
||||
|
||||
# Check that we can correctly call the compiled function with the old name, without raising errors
|
||||
out = compiled_function(deprecated_factor=2)
|
||||
self.assertEqual(out.item(), 2)
|
||||
|
||||
# Check that we can correctly call the compiled function with the new name, without raising errors
|
||||
out = compiled_function(new_factor=2)
|
||||
self.assertEqual(out.item(), 2)
|
||||
|
||||
# Check that we can correctly call the compiled function with both names, without raising errors
|
||||
out = compiled_function(new_factor=2, deprecated_factor=10)
|
||||
self.assertEqual(out.item(), 2)
|
||||
|
||||
Reference in New Issue
Block a user