[Backend support] Allow num_logits_to_keep as Tensor + add flag (#35757)
* support * Update modeling_utils.py * style * most models * Other models * fix-copies * tests + generation utils
This commit is contained in:
@@ -17,10 +17,15 @@ import warnings
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import __version__
|
||||
from transformers import __version__, is_torch_available
|
||||
from transformers.testing_utils import require_torch_gpu
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
INFINITE_VERSION = "9999.0.0"
|
||||
|
||||
|
||||
@@ -168,3 +173,23 @@ class DeprecationDecoratorTester(unittest.TestCase):
|
||||
with self.assertWarns(FutureWarning):
|
||||
result = dummy_function(deprecated_name="old_value", new_name="new_value")
|
||||
self.assertEqual(result, "new_value")
|
||||
|
||||
@require_torch_gpu
|
||||
def test_compile_safe(self):
|
||||
@deprecate_kwarg("deprecated_factor", new_name="new_factor", version=INFINITE_VERSION)
|
||||
def dummy_function(new_factor=None, **kwargs):
|
||||
return new_factor * torch.ones(1, device="cuda")
|
||||
|
||||
compiled_function = torch.compile(dummy_function, fullgraph=True)
|
||||
|
||||
# Check that we can correctly call the compiled function with the old name, without raising errors
|
||||
out = compiled_function(deprecated_factor=2)
|
||||
self.assertEqual(out.item(), 2)
|
||||
|
||||
# Check that we can correctly call the compiled function with the new name, without raising errors
|
||||
out = compiled_function(new_factor=2)
|
||||
self.assertEqual(out.item(), 2)
|
||||
|
||||
# Check that we can correctly call the compiled function with both names, without raising errors
|
||||
out = compiled_function(new_factor=2, deprecated_factor=10)
|
||||
self.assertEqual(out.item(), 2)
|
||||
|
||||
Reference in New Issue
Block a user