ENH: added new output_logits option to generate function (#28667)
output_logits option behaves like output_scores, but returns the raw, unprocessed prediction logit scores, ie. the values before they undergo logit processing and/or warping. The latter happens by default for the regular output scores. It's useful to have the unprocessed logit scores in certain circumstances. For example, unprocessed logit scores are very useful with causallm models when one wants to determine the probability of a certain answer, e.g. when asking a question with a yes/no answer. In that case getting the next-token probabilities of both "yes" and "no" (and/or their relative ratio) is of interest for classification. The reason for getting these _before_ logit processing and/or warping is b/c a) that can change the probabilities or b) reject the tokens of interest / reduce the number of tokens to just 1. For an example use-case see paper TabLLM: Few-shot Classification of Tabular Data with Large Language Models by Stefan Hegselmann, Alejandro Buendia, Hunter Lang, Monica Agrawal, Xiaoyi Jiang, and David Sontag. https://arxiv.org/abs/2210.10723 In addition: - added dedicated unit test: tests/generation/test_utils/test_return_unprocessed_logit_scores which tests return of logics with output_logits=True in generation. - set output_logits=True in all other generation unit tests, that also have output_scores=True. Implemented @gante's and @amyeroberts review feedback Co-authored-by: kx79wq <max.baak@ing.com>
This commit is contained in:
@@ -269,6 +269,7 @@ class GenerationTesterMixin:
|
||||
attention_mask,
|
||||
max_length,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
@@ -293,6 +294,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**logits_process_kwargs,
|
||||
**model_kwargs,
|
||||
@@ -317,6 +319,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
**model_kwargs,
|
||||
@@ -335,6 +338,7 @@ class GenerationTesterMixin:
|
||||
logits_warper_kwargs,
|
||||
process_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
@@ -348,6 +352,7 @@ class GenerationTesterMixin:
|
||||
max_length=max_length,
|
||||
num_return_sequences=num_return_sequences,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
@@ -379,6 +384,7 @@ class GenerationTesterMixin:
|
||||
logits_processor=logits_processor,
|
||||
logits_warper=logits_warper,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
@@ -399,6 +405,7 @@ class GenerationTesterMixin:
|
||||
logits_processor,
|
||||
logits_process_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
@@ -409,6 +416,7 @@ class GenerationTesterMixin:
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
@@ -440,6 +448,7 @@ class GenerationTesterMixin:
|
||||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
@@ -459,6 +468,7 @@ class GenerationTesterMixin:
|
||||
logits_warper,
|
||||
logits_warper_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
@@ -470,6 +480,7 @@ class GenerationTesterMixin:
|
||||
do_sample=True,
|
||||
max_length=max_length,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
@@ -506,6 +517,7 @@ class GenerationTesterMixin:
|
||||
logits_warper=logits_warper,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
@@ -526,6 +538,7 @@ class GenerationTesterMixin:
|
||||
logits_processor,
|
||||
logits_process_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
@@ -536,6 +549,7 @@ class GenerationTesterMixin:
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
@@ -567,6 +581,7 @@ class GenerationTesterMixin:
|
||||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
@@ -587,6 +602,7 @@ class GenerationTesterMixin:
|
||||
logits_processor,
|
||||
logits_process_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
@@ -597,6 +613,7 @@ class GenerationTesterMixin:
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
@@ -629,6 +646,7 @@ class GenerationTesterMixin:
|
||||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
@@ -644,6 +662,7 @@ class GenerationTesterMixin:
|
||||
attention_mask,
|
||||
max_length,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
@@ -673,6 +692,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**logits_process_kwargs,
|
||||
**model_kwargs,
|
||||
@@ -699,6 +719,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
**model_kwargs,
|
||||
@@ -729,6 +750,7 @@ class GenerationTesterMixin:
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
@@ -769,6 +791,7 @@ class GenerationTesterMixin:
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
@@ -853,6 +876,7 @@ class GenerationTesterMixin:
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
process_kwargs=process_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
@@ -964,6 +988,7 @@ class GenerationTesterMixin:
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
@@ -1032,6 +1057,7 @@ class GenerationTesterMixin:
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
logits_processor=logits_processor,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
@@ -1126,6 +1152,7 @@ class GenerationTesterMixin:
|
||||
logits_warper=logits_warper,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
@@ -1262,6 +1289,7 @@ class GenerationTesterMixin:
|
||||
logits_processor=logits_processor,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
@@ -1421,6 +1449,7 @@ class GenerationTesterMixin:
|
||||
logits_processor=logits_processor,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
@@ -1493,6 +1522,7 @@ class GenerationTesterMixin:
|
||||
attention_mask=attention_mask,
|
||||
max_length=max_length,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
@@ -1628,6 +1658,7 @@ class GenerationTesterMixin:
|
||||
"num_beams": 1,
|
||||
"do_sample": False,
|
||||
"output_scores": True,
|
||||
"output_logits": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"return_dict_in_generate": True,
|
||||
@@ -1690,6 +1721,7 @@ class GenerationTesterMixin:
|
||||
"num_beams": 1,
|
||||
"do_sample": False,
|
||||
"output_scores": True,
|
||||
"output_logits": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"return_dict_in_generate": True,
|
||||
@@ -1753,6 +1785,7 @@ class GenerationTesterMixin:
|
||||
"do_sample": True,
|
||||
"assistant_model": assistant_model,
|
||||
"output_scores": True,
|
||||
"output_logits": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"return_dict_in_generate": True,
|
||||
@@ -2105,6 +2138,7 @@ class GenerationTesterMixin:
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
|
||||
gen_len = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
||||
)
|
||||
@@ -2112,6 +2146,9 @@ class GenerationTesterMixin:
|
||||
# scores
|
||||
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
|
||||
|
||||
# unprocessed logits
|
||||
self._check_logits(num_sequences_in_output, output.logits, config=config)
|
||||
|
||||
# Attentions
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
@@ -2191,6 +2228,14 @@ class GenerationTesterMixin:
|
||||
self.assertEqual(len(scores), length)
|
||||
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
|
||||
|
||||
def _check_logits(self, batch_size, scores, config):
|
||||
self.assertIsInstance(scores, tuple)
|
||||
self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
|
||||
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
|
||||
vocab_diff = config.vocab_size - scores[0].shape[-1]
|
||||
self.assertTrue(vocab_diff in [0, 1])
|
||||
self.assertListEqual([config.vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
@@ -3536,3 +3581,60 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
model.generate(**inputs, **generation_kwargs)
|
||||
# update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5
|
||||
self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5)
|
||||
|
||||
def test_compare_unprocessed_logit_scores(self):
|
||||
# Get unprocessed logit scores back from model generate function.
|
||||
# Assert that unprocessed logits from generate() are same as those from modal eval()
|
||||
|
||||
# tell model to generate text and return unprocessed/unwarped logit scores
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = "generate yes or no: "
|
||||
input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Get logits for the next token from fwd pass
|
||||
logits_fwd = model(input_ids).logits[:, -1, :][0]
|
||||
|
||||
# Get logits for the next token from generate function
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
return_dict_in_generate=True,
|
||||
output_logits=True,
|
||||
max_new_tokens=1,
|
||||
do_sample=True,
|
||||
)
|
||||
logits_gen = outputs.logits[0][0]
|
||||
|
||||
# assert that unprocessed logits from generate() are same as those from modal eval()
|
||||
self.assertListEqual(logits_fwd.tolist(), logits_gen.tolist())
|
||||
|
||||
def test_return_unprocessed_logit_scores(self):
|
||||
# tell model to generate text and return unprocessed/unwarped logit scores
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = "generate yes or no: "
|
||||
input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device)
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids, return_dict_in_generate=True, output_logits=True, max_new_tokens=3
|
||||
)
|
||||
|
||||
# perform dummy check if unpreprocessed logits make sense.
|
||||
# do preselection on high probabilities; find scores of y and n tokens
|
||||
probs_all = torch.nn.functional.softmax(outputs.logits[2][0], dim=-1)
|
||||
indices = torch.argwhere(probs_all > 0.001)
|
||||
indices = indices[:, -1]
|
||||
tokens_max = tokenizer.batch_decode(indices, skip_special_tokens=True)
|
||||
probs_max = probs_all[probs_all > 0.001]
|
||||
|
||||
self.assertTrue(len(indices) >= 2)
|
||||
next_token_dict = {str(t): p for t, p in zip(tokens_max, probs_max)}
|
||||
self.assertTrue("n" in next_token_dict)
|
||||
self.assertTrue("y" in next_token_dict)
|
||||
y_prob = next_token_dict["y"]
|
||||
n_prob = next_token_dict["n"]
|
||||
|
||||
self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
|
||||
self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
|
||||
|
||||
Reference in New Issue
Block a user