[Backend support] Allow num_logits_to_keep as Tensor + add flag (#35757)
* support * Update modeling_utils.py * style * most models * Other models * fix-copies * tests + generation utils
This commit is contained in:
@@ -4759,21 +4759,21 @@ class ModelTesterMixin:
|
||||
for name, param in model._orig_mod.named_parameters():
|
||||
torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_forward_with_num_logits_to_keep(self):
|
||||
def test_forward_with_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
|
||||
if "logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||
self.skipTest(reason="This model does not support `logits_to_keep` argument.")
|
||||
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
batch_size, sequence_length = inputs["input_ids"].shape
|
||||
vocab_size = config.get_text_config().vocab_size
|
||||
model = model_class(config).to(device=torch_device).eval()
|
||||
# some models have labels but `num_logits_to_keep` should not be used in train mode
|
||||
# some models have labels but `logits_to_keep` should not be used in train mode
|
||||
_ = inputs.pop("labels", None)
|
||||
|
||||
# num_logits_to_keep=0 is a special case meaning "keep all logits"
|
||||
all_logits = model(**inputs, num_logits_to_keep=0).logits
|
||||
last_token_logits = model(**inputs, num_logits_to_keep=1).logits
|
||||
# logits_to_keep=0 is a special case meaning "keep all logits"
|
||||
all_logits = model(**inputs, logits_to_keep=0).logits
|
||||
last_token_logits = model(**inputs, logits_to_keep=1).logits
|
||||
|
||||
# Assert all shapes are correct
|
||||
self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size))
|
||||
|
||||
Reference in New Issue
Block a user