[generate] beam search -- fix output cropping (#37080)

* handle jagged beams

* better comment

* bart -- beam search tests print special tokens

* more bart test updates

* more tests!

* better comment
Joao Gante
2025-03-28 17:57:51 +00:00
committed by GitHub
parent 257bc670fb
commit 9fd9476005
5 changed files with 74 additions and 45 deletions


@@ -3931,9 +3931,14 @@ class GenerationMixin:
beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
-# Crop the static-shaped tensors to the actual size
-sequences = sequences[:, :cur_len]
-beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]
+# Crop the static-shaped tensors to the actual size.
+# `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
+# step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
+# previous decoding iteration)
+max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
+output_length = decoder_prompt_len + max_generated_length
+sequences = sequences[:, :output_length]
+beam_indices = beam_indices[:, :max_generated_length]
if return_dict_in_generate:
if not output_scores:
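
The cropping rule in this hunk can be illustrated with a minimal sketch that uses plain Python lists instead of tensors. The function name `crop_beam_outputs`, the token ids, and the pad id `1` are assumptions made up for illustration; they are not part of the patch. Counting entries that differ from `-1` should match what `((beam_indices + 1).bool()).sum(dim=1)` computes, given that `-1` is the only filler value.

```python
def crop_beam_outputs(sequences, beam_indices, decoder_prompt_len):
    """Crop statically sized beam search outputs (hypothetical helper).

    sequences: rows of token ids (prompt + generated tokens + padding).
    beam_indices: rows of beam indices, with -1 in slots where no token
        was generated.
    """
    # A slot holds a real beam index exactly when it is not the -1 filler.
    generated_lengths = [
        sum(1 for idx in row if idx != -1) for row in beam_indices
    ]
    # The longest returned beam decides how much of the buffers to keep.
    max_generated_length = max(generated_lengths)
    output_length = decoder_prompt_len + max_generated_length

    cropped_sequences = [row[:output_length] for row in sequences]
    cropped_beam_indices = [row[:max_generated_length] for row in beam_indices]
    return cropped_sequences, cropped_beam_indices


# Made-up example: prompt length 1, buffers sized for 5 generated tokens.
# Both returned beams finished before the last decoding step, so cropping
# by the loop counter alone would keep trailing padding.
example_sequences = [
    [0, 11, 12, 1, 1, 1],   # 2 generated tokens, then padding (pad id 1)
    [0, 21, 22, 23, 1, 1],  # 3 generated tokens, then padding
]
example_beam_indices = [
    [0, 0, -1, -1, -1],
    [1, 0, 1, -1, -1],
]

cropped_sequences, cropped_beam_indices = crop_beam_outputs(
    example_sequences, example_beam_indices, decoder_prompt_len=1
)
print(cropped_sequences)
print(cropped_beam_indices)
```

In this example the outputs are cropped to the prompt plus three generated tokens, the length of the longest returned beam, even if the decoding loop itself ran for more steps.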