[Backend support] Allow num_logits_to_keep as Tensor + add flag (#35757)
* support * Update modeling_utils.py * style * most models * Other models * fix-copies * tests + generation utils
This commit is contained in:
@@ -531,7 +531,7 @@ class BambaModelIntegrationTest(unittest.TestCase):
|
||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||
if self.cuda_compute_capability_major_version == 8:
|
||||
with torch.no_grad():
|
||||
logits = self.model(input_ids=input_ids, num_logits_to_keep=40).logits
|
||||
logits = self.model(input_ids=input_ids, logits_to_keep=40).logits
|
||||
|
||||
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user