[generate] beam search -- fix output cropping (#37080)
* handle jagged beams
* better comment
* bart -- beam search tests print special tokens
* more bart test updates
* more tests!
* better comment
@@ -3931,9 +3931,14 @@ class GenerationMixin:
         beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
         beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
 
-        # Crop the static-shaped tensors to the actual size
-        sequences = sequences[:, :cur_len]
-        beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]
+        # Crop the static-shaped tensors to the actual size.
+        # `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
+        # step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
+        # previous decoding iteration)
+        max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
+        output_length = decoder_prompt_len + max_generated_length
+        sequences = sequences[:, :output_length]
+        beam_indices = beam_indices[:, :max_generated_length]
 
         if return_dict_in_generate:
             if not output_scores:
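
The new cropping rule can be illustrated with a small sketch. This is not the library code: it uses plain Python lists as a stand-in for the tensors, the helper name `crop_to_generated_length` is hypothetical, and all token ids and beam indices are made up. It assumes, as the comment in the hunk states, that unused `beam_indices` slots still hold `-1`.

```python
# Plain-Python stand-in for the tensor logic in the hunk above.
# Hypothetical helper and made-up data, for illustration only.
PAD_INDEX = -1  # `beam_indices` is initialized with -1s


def crop_to_generated_length(sequences, beam_indices, decoder_prompt_len):
    """Crop static-shaped rows to the longest actually generated length."""
    # Count the filled (!= -1) slots per row and take the maximum.
    # Mirrors ((beam_indices + 1).bool()).sum(dim=1).max() from the diff.
    max_generated_length = max(
        sum(1 for idx in row if idx != PAD_INDEX) for row in beam_indices
    )
    output_length = decoder_prompt_len + max_generated_length
    cropped_sequences = [row[:output_length] for row in sequences]
    cropped_indices = [row[:max_generated_length] for row in beam_indices]
    return cropped_sequences, cropped_indices


# Two returned beams, a 2-token decoder prompt, 6 static generation slots.
# The selected beams finished after 2 and 3 generated tokens respectively.
sequences = [
    [101, 7, 11, 2, 0, 0, 0, 0],
    [101, 7, 12, 13, 2, 0, 0, 0],
]
beam_indices = [
    [0, 1, -1, -1, -1, -1],
    [0, 0, 1, -1, -1, -1],
]

cropped_sequences, cropped_indices = crop_to_generated_length(
    sequences, beam_indices, decoder_prompt_len=2
)
print(cropped_sequences)
print(cropped_indices)
```

If decoding had continued to, say, `cur_len = 6` in this toy setup, the previous rule (`sequences[:, :cur_len]`) would have kept a trailing column of padding, whereas counting the filled `beam_indices` slots crops to length 5.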