From af37d183b30f0e4430e479aab509ab0e7cb553c9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 20 Jan 2023 12:50:01 +0000 Subject: [PATCH] Generate: documented function to compute the transition scores (#21191) Co-authored-by: Patrick von Platen --- .../en/main_classes/text_generation.mdx | 1 + src/transformers/generation/utils.py | 107 +++++++++++++++--- tests/generation/test_utils.py | 76 ++++++++++--- 3 files changed, 154 insertions(+), 30 deletions(-) diff --git a/docs/source/en/main_classes/text_generation.mdx b/docs/source/en/main_classes/text_generation.mdx index 2a13eae950..0a796d007b 100644 --- a/docs/source/en/main_classes/text_generation.mdx +++ b/docs/source/en/main_classes/text_generation.mdx @@ -37,6 +37,7 @@ and how to create and save a customized generation configuration, refer to the [[autodoc]] generation.GenerationMixin - generate + - compute_transition_scores - greedy_search - sample - beam_search diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b47c4db3e3..efaafb8f24 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -924,42 +924,121 @@ class GenerationMixin: default_list.extend(custom_list) return default_list - def compute_transition_beam_scores( + def compute_transition_scores( self, sequences: torch.Tensor, scores: Tuple[torch.Tensor], - beam_indices: torch.Tensor, - eos_token_id: Union[int, List[int]] = None, - ): - """compute the transition probabilities of sequences given generation - scores and beam indices""" + beam_indices: Optional[torch.Tensor] = None, + normalize_logits: bool = False, + ) -> torch.Tensor: + """ + Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was + used). This is a convenient method to quicky obtain the scores of the selected tokens at generation time. - # 1. reshape scores as [vocab_size * batch_size, # generation steps] - # with batch_size being 2 * vocab_size and # generation steps being + Parameters: + sequences (`torch.LongTensor`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or + shorter if all batches finished early due to the `eos_token_id`. + scores (`tuple(torch.FloatTensor)`): + Transition scores for each vocabulary token at each generation step. Beam transition scores consisting + of log probabilities of tokens conditioned on log softmax of previously generated tokens Tuple of + `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with + each tensor of shape `(batch_size*num_beams, config.vocab_size)`. + beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*): + Beam indices of generated token id at each generation step. `torch.LongTensor` of shape + `(batch_size*num_return_sequences, input_ids.shape[-1])`. Only required if a `num_beams>1` at + generate-time. + normalize_logits (`bool`, *optional*, defaults to `False`): + Whether to normalize the logits (which, for legacy reasons, may be unnormalized). + + Return: + `torch.Tensor`: A `torch.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing + the transition scores (logits) + + Examples: + + ```python + >>> from transformers import GPT2Tokenizer, AutoModelForCausalLM + >>> import numpy as np + + >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> tokenizer.pad_token_id = tokenizer.eos_token_id + >>> inputs = tokenizer(["Today is"], return_tensors="pt") + + >>> # Example 1: Print the scores for each token generated with Greedy Search + >>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True) + >>> transition_scores = model.compute_transition_scores( + ... outputs.sequences, outputs.scores, normalize_logits=True + ... ) + >>> input_length = inputs.input_ids.shape[1] + >>> generated_tokens = outputs.sequences[:, input_length:] + >>> for tok, score in zip(generated_tokens[0], transition_scores[0]): + ... # | token | token string | logits | probability + ... print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.4f} | {np.exp(score.numpy()):.2%}") + | 262 | the | -1.4136 | 24.33% + | 1110 | day | -2.6089 | 7.36% + | 618 | when | -2.0096 | 13.40% + | 356 | we | -1.8593 | 15.58% + | 460 | can | -2.5083 | 8.14% + + >>> # Example 2: Reconstruct the sequence scores from Beam Search + >>> outputs = model.generate( + ... **inputs, + ... max_new_tokens=5, + ... num_beams=4, + ... num_return_sequences=4, + ... return_dict_in_generate=True, + ... output_scores=True, + ... ) + >>> transition_scores = model.compute_transition_scores( + ... outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False + ... ) + >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores. + >>> # Tip: set `normalize_logits=True` to recompute the scores from the normalized logits. + >>> output_length = inputs.input_ids.shape[1] + np.sum(transition_scores.numpy() < 0, axis=1) + >>> length_penalty = model.generation_config.length_penalty + >>> reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty) + >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores)) + True + ```""" + # 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent + # to a beam search approach were the first (and only) beam is always selected + if beam_indices is None: + beam_indices = torch.arange(scores[0].shape[0]).view(-1, 1).to(sequences.device) + beam_indices = beam_indices.expand(-1, len(scores)) + + # 2. reshape scores as [batch_size*vocab_size, # generation steps] with # generation steps being # seq_len - input_length scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1) - # 2. cut beam_indices to longest beam length + # 3. Optionally normalize the logits (across the vocab dimension) + if normalize_logits: + scores = scores.reshape(-1, self.config.vocab_size, scores.shape[-1]) + scores = torch.nn.functional.log_softmax(scores, dim=1) + scores = scores.reshape(-1, scores.shape[-1]) + + # 4. cut beam_indices to longest beam length beam_indices_mask = beam_indices < 0 max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max() beam_indices = beam_indices[:, :max_beam_length] beam_indices_mask = beam_indices_mask[:, :max_beam_length] - # 3. Set indices of beams that finished early to 0 + # 5. Set indices of beams that finished early to 0 # such indices will be masked correctly afterwards beam_indices[beam_indices_mask] = 0 - # 4. multiply beam_indices with vocab size to gather correctly from scores + # 6. multiply beam_indices with vocab size to gather correctly from scores beam_sequence_indices = beam_indices * self.config.vocab_size - # 5. Define which indices contributed to scores + # 7. Define which indices contributed to scores cut_idx = sequences.shape[-1] - max_beam_length indices = sequences[:, cut_idx:] + beam_sequence_indices - # 6. Compute scores + # 8. Compute scores transition_scores = scores.gather(0, indices) - # 7. Mask out transition_scores of beams that stopped early + # 9. Mask out transition_scores of beams that stopped early transition_scores[beam_indices_mask] = 0 return transition_scores diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index aeb2bf480b..3339b60091 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -17,6 +17,8 @@ import inspect import unittest +import numpy as np + from transformers import is_torch_available, pipeline from transformers.testing_utils import require_torch, slow, torch_device @@ -2485,6 +2487,58 @@ class GenerationIntegrationTests(unittest.TestCase): self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist()) + def test_transition_scores_greedy_search(self): + articles = ["Justin Timberlake", "Michael Phelps"] + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + tokenizer.pad_token = tokenizer.eos_token + + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) + outputs = model.generate( + input_ids=input_ids, + max_new_tokens=5, + pad_token_id=tokenizer.eos_token_id, + eos_token_id=None, + return_dict_in_generate=True, + output_scores=True, + ) + + transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores) + expected_scores = np.array( + [ + [0.3596273, 0.39646253, 0.46157718, 0.4594633, 0.44866616], + [0.34934354, 0.4935004, 0.6373219, 0.5173545, 0.57517034], + ] + ) + self.assertTrue(np.allclose(transition_scores.cpu().numpy(), expected_scores)) + + def test_transition_scores_greedy_search_normalized(self): + articles = ["Justin Timberlake", "Michael Phelps"] + tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + tokenizer.pad_token = tokenizer.eos_token + + model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) + outputs = model.generate( + input_ids=input_ids, + max_new_tokens=5, + pad_token_id=tokenizer.eos_token_id, + eos_token_id=None, + return_dict_in_generate=True, + output_scores=True, + ) + + transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True) + expected_scores = np.array( + [ + [-6.5532393, -6.5158753, -6.451863, -6.4527144, -6.459402], + [-6.5685124, -6.4277077, -6.282607, -6.399295, -6.340927], + ] + ) + self.assertTrue(np.allclose(transition_scores.cpu().numpy(), expected_scores)) + def test_transition_scores_beam_search_encoder_decoder(self): articles = [ "Justin Timberlake and Jessica Biel, welcome to parenthood.", @@ -2506,9 +2560,7 @@ class GenerationIntegrationTests(unittest.TestCase): input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) outputs = model.generate(input_ids=input_ids) - transition_scores = model.compute_transition_beam_scores( - outputs.sequences, outputs.scores, outputs.beam_indices - ) + transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices) transition_scores_sum = transition_scores.sum(-1) self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3)) @@ -2533,9 +2585,7 @@ class GenerationIntegrationTests(unittest.TestCase): input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) outputs = model.generate(input_ids=input_ids) - transition_scores = model.compute_transition_beam_scores( - outputs.sequences, outputs.scores, outputs.beam_indices - ) + transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices) transition_scores_sum = transition_scores.sum(-1) self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3)) @@ -2564,9 +2614,7 @@ class GenerationIntegrationTests(unittest.TestCase): input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) outputs = model.generate(input_ids=input_ids) - transition_scores = model.compute_transition_beam_scores( - outputs.sequences, outputs.scores, outputs.beam_indices - ) + transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices) transition_scores_sum = transition_scores.sum(-1) self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3)) @@ -2593,9 +2641,7 @@ class GenerationIntegrationTests(unittest.TestCase): input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) outputs = model.generate(input_ids=input_ids) - transition_scores = model.compute_transition_beam_scores( - outputs.sequences, outputs.scores, outputs.beam_indices - ) + transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices) transition_scores_sum = transition_scores.sum(-1) self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3)) @@ -2622,9 +2668,7 @@ class GenerationIntegrationTests(unittest.TestCase): input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device) outputs = model.generate(input_ids=input_ids) - transition_scores = model.compute_transition_beam_scores( - outputs.sequences, outputs.scores, outputs.beam_indices - ) + transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices) transition_scores_sum = transition_scores.sum(-1) self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3)) @@ -2653,7 +2697,7 @@ class GenerationIntegrationTests(unittest.TestCase): length_penalty=0.0, ) - transition_scores = model.compute_transition_beam_scores( + transition_scores = model.compute_transition_scores( sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices )