[Backend support] Allow num_logits_to_keep as Tensor + add flag (#35757)

* support

* Update modeling_utils.py

* style

* most models

* Other models

* fix-copies

* tests + generation utils
This commit is contained in:
Cyril Vallez
2025-01-23 09:47:54 +01:00
committed by GitHub
parent 8736e91ad6
commit d3af76df58
62 changed files with 603 additions and 315 deletions

View File

@@ -531,7 +531,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
if self.cuda_compute_capability_major_version == 8:
with torch.no_grad():
logits = self.model(input_ids=input_ids, num_logits_to_keep=40).logits
logits = self.model(input_ids=input_ids, logits_to_keep=40).logits
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
[