[Generation] Fix Transition probs (#17311)
* [Draft] fix transition probs * up * up * up * make it work * fix * finish * update
This commit is contained in:
committed by
GitHub
parent
e8714c0307
commit
518bd02c9b
@@ -212,6 +212,7 @@ class BeamSearchScorer(BeamScorer):
|
||||
next_indices: torch.LongTensor,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
beam_indices: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
cur_len = input_ids.shape[-1]
|
||||
batch_size = len(self._beam_hyps)
|
||||
@@ -256,9 +257,16 @@ class BeamSearchScorer(BeamScorer):
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
continue
|
||||
if beam_indices is not None:
|
||||
beam_index = beam_indices[batch_beam_idx]
|
||||
beam_index = beam_index + (next_index,)
|
||||
else:
|
||||
beam_index = None
|
||||
|
||||
beam_hyp.add(
|
||||
input_ids[batch_beam_idx].clone(),
|
||||
next_score.item(),
|
||||
beam_indices=beam_index,
|
||||
)
|
||||
else:
|
||||
# add next predicted token since it is not eos_token
|
||||
@@ -299,6 +307,7 @@ class BeamSearchScorer(BeamScorer):
|
||||
max_length: int,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
beam_indices: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.LongTensor]:
|
||||
batch_size = len(self._beam_hyps)
|
||||
|
||||
@@ -313,11 +322,13 @@ class BeamSearchScorer(BeamScorer):
|
||||
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
||||
final_score = final_beam_scores[batch_beam_idx].item()
|
||||
final_tokens = input_ids[batch_beam_idx]
|
||||
beam_hyp.add(final_tokens, final_score)
|
||||
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
||||
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)
|
||||
|
||||
# select the best hypotheses
|
||||
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
||||
best = []
|
||||
best_indices = []
|
||||
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
||||
|
||||
# retrieve best hypotheses
|
||||
@@ -327,23 +338,42 @@ class BeamSearchScorer(BeamScorer):
|
||||
best_hyp_tuple = sorted_hyps.pop()
|
||||
best_score = best_hyp_tuple[0]
|
||||
best_hyp = best_hyp_tuple[1]
|
||||
best_index = best_hyp_tuple[2]
|
||||
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
||||
|
||||
# append to lists
|
||||
# append hyp to lists
|
||||
best.append(best_hyp)
|
||||
|
||||
# append indices to list
|
||||
best_indices.append(best_index)
|
||||
|
||||
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
||||
|
||||
# prepare for adding eos
|
||||
sent_lengths_max = sent_lengths.max().item() + 1
|
||||
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
||||
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||
|
||||
if len(best_indices) > 0 and best_indices[0] is not None:
|
||||
indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||
else:
|
||||
indices = None
|
||||
|
||||
# shorter batches are padded if needed
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
assert pad_token_id is not None, "`pad_token_id` has to be defined"
|
||||
decoded.fill_(pad_token_id)
|
||||
|
||||
if indices is not None:
|
||||
indices.fill_(-1)
|
||||
|
||||
# fill with hypotheses and eos_token_id if the latter fits in
|
||||
for i, hypo in enumerate(best):
|
||||
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
|
||||
if indices is not None:
|
||||
indices[i, : len(best_idx)] = torch.tensor(best_idx)
|
||||
|
||||
if sent_lengths[i] < sent_max_len:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
|
||||
@@ -351,6 +381,7 @@ class BeamSearchScorer(BeamScorer):
|
||||
{
|
||||
"sequences": decoded,
|
||||
"sequence_scores": best_scores,
|
||||
"beam_indices": indices,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -789,6 +820,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
|
||||
# prepare for adding eos
|
||||
sent_lengths_max = sent_lengths.max().item() + 1
|
||||
|
||||
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
||||
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||
# shorter batches are padded if needed
|
||||
@@ -801,6 +833,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < sent_max_len:
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
|
||||
return UserDict(
|
||||
{
|
||||
"sequences": decoded,
|
||||
@@ -826,15 +859,15 @@ class BeamHypotheses:
|
||||
"""
|
||||
return len(self.beams)
|
||||
|
||||
def add(self, hyp: torch.LongTensor, sum_logprobs: float):
|
||||
def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
|
||||
"""
|
||||
Add a new hypothesis to the list.
|
||||
"""
|
||||
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
||||
if len(self) < self.num_beams or score > self.worst_score:
|
||||
self.beams.append((score, hyp))
|
||||
self.beams.append((score, hyp, beam_indices))
|
||||
if len(self) > self.num_beams:
|
||||
sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
|
||||
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
|
||||
del self.beams[sorted_next_scores[0][1]]
|
||||
self.worst_score = sorted_next_scores[1][0]
|
||||
else:
|
||||
|
||||
@@ -217,8 +217,8 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
|
||||
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
||||
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
||||
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
||||
tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
||||
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
||||
`(batch_size*num_return_sequences, input_ids.shape[-1])`.
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
||||
@@ -230,7 +230,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
|
||||
sequences: torch.LongTensor = None
|
||||
sequences_scores: Optional[torch.FloatTensor] = None
|
||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
||||
beam_indices: Optional[torch.LongTensor] = None
|
||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
|
||||
@@ -254,8 +254,8 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
||||
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
||||
config.vocab_size)`).
|
||||
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
||||
tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
||||
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
||||
`(batch_size*num_return_sequences, max_length-1)`.
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
||||
@@ -278,7 +278,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
||||
sequences: torch.LongTensor = None
|
||||
sequences_scores: Optional[torch.FloatTensor] = None
|
||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
||||
beam_indices: Optional[torch.LongTensor] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
@@ -303,8 +303,8 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
|
||||
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
||||
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
||||
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
||||
tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
||||
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
||||
`(batch_size*num_return_sequences, input_ids.shape[-1])`.
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
||||
@@ -316,7 +316,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
|
||||
sequences: torch.LongTensor = None
|
||||
sequences_scores: Optional[torch.FloatTensor] = None
|
||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
||||
beam_indices: Optional[torch.LongTensor] = None
|
||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
|
||||
@@ -339,9 +339,9 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
||||
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
||||
config.vocab_size)`).
|
||||
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
||||
tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
||||
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
||||
`(batch_size*num_return_sequences, max_length-1)`.
|
||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
@@ -362,7 +362,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
||||
sequences: torch.LongTensor = None
|
||||
sequences_scores: Optional[torch.FloatTensor] = None
|
||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
||||
beam_indices: Optional[torch.LongTensor] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
@@ -811,32 +811,33 @@ class GenerationMixin:
|
||||
"""compute the transition probabilities of sequences given generation
|
||||
scores and beam indices"""
|
||||
|
||||
# reshape scores as [vocab_size * batch_size, # generation steps]
|
||||
# 1. reshape scores as [vocab_size * batch_size, # generation steps]
|
||||
# with batch_size being 2 * vocab_size and # generation steps being
|
||||
# seq_len - input_length
|
||||
scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)
|
||||
|
||||
# start of generated tokens
|
||||
cut_idx = sequences.shape[-1] - scores.shape[-1]
|
||||
# adjust for beam indices
|
||||
beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size
|
||||
# compute real indices
|
||||
indices = sequences[:, cut_idx:] + beam_sequence_indices
|
||||
# gather scores and run
|
||||
transition_scores = scores.gather(0, indices)
|
||||
# make sure that if EOS token was used before length of sequence `sequence.shape[-1]`
|
||||
# get first occurrence of EOS token
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
# 2. cut beam_indices to longest beam length
|
||||
beam_indices_mask = beam_indices < 0
|
||||
max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
|
||||
beam_indices = beam_indices[:, :max_beam_length]
|
||||
beam_indices_mask = beam_indices_mask[:, :max_beam_length]
|
||||
|
||||
if eos_token_id is not None:
|
||||
is_eos_token_id = sequences[:, cut_idx:] == eos_token_id
|
||||
# make sure first eos token still contributes to transition probs
|
||||
is_eos_token_id[:, -1] = False
|
||||
is_eos_token_id = is_eos_token_id.roll(1, -1)
|
||||
# all indices after eos should be masked
|
||||
zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool()
|
||||
# zero out padded probs
|
||||
transition_scores.masked_fill_(zero_transition_prob_mask, 0.0)
|
||||
# 3. Set indices of beams that finished early to 0
|
||||
# such indices will be masked correctly afterwards
|
||||
beam_indices[beam_indices_mask] = 0
|
||||
|
||||
# 4. multiply beam_indices with vocab size to gather correctly from scores
|
||||
beam_sequence_indices = beam_indices * self.config.vocab_size
|
||||
|
||||
# 5. Define which indices contributed to scores
|
||||
cut_idx = sequences.shape[-1] - max_beam_length
|
||||
indices = sequences[:, cut_idx:] + beam_sequence_indices
|
||||
|
||||
# 6. Compute scores
|
||||
transition_scores = scores.gather(0, indices)
|
||||
|
||||
# 7. Mask out transition_scores of beams that stopped early
|
||||
transition_scores[beam_indices_mask] = 0
|
||||
|
||||
return transition_scores
|
||||
|
||||
@@ -2256,6 +2257,7 @@ class GenerationMixin:
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
beam_indices=beam_indices,
|
||||
)
|
||||
|
||||
beam_scores = beam_outputs["next_beam_scores"]
|
||||
@@ -2290,25 +2292,19 @@ class GenerationMixin:
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
max_length=stopping_criteria.max_length,
|
||||
beam_indices=beam_indices,
|
||||
)
|
||||
|
||||
if return_dict_in_generate:
|
||||
if not output_scores:
|
||||
sequence_outputs["sequence_scores"] = None
|
||||
else:
|
||||
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
|
||||
# return only as many indices as sequences
|
||||
beam_indices = tuple(
|
||||
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
|
||||
)
|
||||
beam_indices = sum(beam_indices, ())
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
return BeamSearchEncoderDecoderOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
beam_indices=beam_indices,
|
||||
beam_indices=sequence_outputs["beam_indices"],
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
@@ -2320,7 +2316,7 @@ class GenerationMixin:
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
beam_indices=beam_indices,
|
||||
beam_indices=sequence_outputs["beam_indices"],
|
||||
attentions=decoder_attentions,
|
||||
hidden_states=decoder_hidden_states,
|
||||
)
|
||||
@@ -2580,6 +2576,7 @@ class GenerationMixin:
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
beam_indices=beam_indices,
|
||||
)
|
||||
beam_scores = beam_outputs["next_beam_scores"]
|
||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||
@@ -2613,25 +2610,19 @@ class GenerationMixin:
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
max_length=stopping_criteria.max_length,
|
||||
beam_indices=beam_indices,
|
||||
)
|
||||
|
||||
if return_dict_in_generate:
|
||||
if not output_scores:
|
||||
sequence_outputs["sequence_scores"] = None
|
||||
else:
|
||||
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
|
||||
# return only as many indices as sequences
|
||||
beam_indices = tuple(
|
||||
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
|
||||
)
|
||||
beam_indices = sum(beam_indices, ())
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
return BeamSampleEncoderDecoderOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
beam_indices=beam_indices,
|
||||
beam_indices=sequence_outputs["beam_indices"],
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
@@ -2643,7 +2634,7 @@ class GenerationMixin:
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
beam_indices=beam_indices,
|
||||
beam_indices=sequence_outputs["beam_indices"],
|
||||
attentions=decoder_attentions,
|
||||
hidden_states=decoder_hidden_states,
|
||||
)
|
||||
@@ -2909,6 +2900,7 @@ class GenerationMixin:
|
||||
next_tokens = next_tokens % vocab_size
|
||||
|
||||
# stateless
|
||||
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
||||
beam_outputs = beam_scorer.process(
|
||||
group_input_ids,
|
||||
next_token_scores,
|
||||
@@ -2916,6 +2908,7 @@ class GenerationMixin:
|
||||
next_indices,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
beam_indices=process_beam_indices,
|
||||
)
|
||||
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||
@@ -2971,6 +2964,7 @@ class GenerationMixin:
|
||||
else:
|
||||
this_peer_finished = True
|
||||
|
||||
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
||||
sequence_outputs = beam_scorer.finalize(
|
||||
input_ids,
|
||||
beam_scores,
|
||||
@@ -2979,26 +2973,19 @@ class GenerationMixin:
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
max_length=stopping_criteria.max_length,
|
||||
beam_indices=final_beam_indices,
|
||||
)
|
||||
|
||||
if return_dict_in_generate:
|
||||
if not output_scores:
|
||||
sequence_outputs["sequence_scores"] = None
|
||||
else:
|
||||
beam_indices = sum(beam_indices, ())
|
||||
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
|
||||
# return only as many indices as sequences
|
||||
beam_indices = tuple(
|
||||
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
|
||||
)
|
||||
beam_indices = sum(beam_indices, ())
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
return BeamSearchEncoderDecoderOutput(
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
beam_indices=beam_indices,
|
||||
beam_indices=sequence_outputs["beam_indices"],
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
decoder_attentions=decoder_attentions,
|
||||
@@ -3010,6 +2997,7 @@ class GenerationMixin:
|
||||
sequences=sequence_outputs["sequences"],
|
||||
sequences_scores=sequence_outputs["sequence_scores"],
|
||||
scores=scores,
|
||||
beam_indices=sequence_outputs["beam_indices"],
|
||||
attentions=decoder_attentions,
|
||||
hidden_states=decoder_hidden_states,
|
||||
)
|
||||
|
||||
@@ -126,7 +126,11 @@ class BeamSearchTester:
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[:, : self.num_beams] = self.eos_token_id
|
||||
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
|
||||
beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
|
||||
beam_indices = tuple(tuple(b) for b in beam_indices)
|
||||
beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
|
||||
)
|
||||
# beam scorer should be done
|
||||
self.parent.assertTrue(beam_scorer.is_done)
|
||||
|
||||
@@ -136,7 +140,7 @@ class BeamSearchTester:
|
||||
tokens = next_tokens.clone()
|
||||
tokens[:, 1] = self.eos_token_id
|
||||
beam_outputs = beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
|
||||
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
|
||||
)
|
||||
output_scores = beam_outputs["next_beam_scores"]
|
||||
output_tokens = beam_outputs["next_beam_tokens"]
|
||||
@@ -161,10 +165,15 @@ class BeamSearchTester:
|
||||
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
|
||||
|
||||
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
|
||||
expected_beam_indices = list(range(10))
|
||||
for batch_idx in range(self.batch_size):
|
||||
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
|
||||
self.parent.assertListEqual(
|
||||
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
|
||||
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
expected_beam_indices + [next_indices[batch_idx, 1].item()],
|
||||
torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
|
||||
)
|
||||
|
||||
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
|
||||
@@ -188,6 +197,8 @@ class BeamSearchTester:
|
||||
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
# finalize
|
||||
beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
|
||||
beam_indices = tuple(tuple(b) for b in beam_indices)
|
||||
sequence_output = beam_scorer.finalize(
|
||||
input_ids,
|
||||
output_scores,
|
||||
@@ -196,6 +207,7 @@ class BeamSearchTester:
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
beam_indices=beam_indices,
|
||||
)
|
||||
|
||||
sequences = sequence_output["sequences"]
|
||||
@@ -225,6 +237,7 @@ class BeamSearchTester:
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
beam_indices=beam_indices,
|
||||
)
|
||||
sequences = sequence_output["sequences"]
|
||||
sequence_scores = sequence_output["sequence_scores"]
|
||||
@@ -394,7 +407,7 @@ class ConstrainedBeamSearchTester:
|
||||
for batch_idx in range(self.batch_size):
|
||||
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
|
||||
self.parent.assertListEqual(
|
||||
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
|
||||
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
|
||||
)
|
||||
|
||||
def check_constrained_beam_scorer_finalize(
|
||||
|
||||
@@ -2322,6 +2322,94 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_transition_scores_early_stopping(self):
|
||||
# This is an aggressive test that makes sure that `beam_search`'s
|
||||
# transition scores are computed correctly for varying `num_return_sequences`,
|
||||
# `num_beams` and `batch_size > 1`
|
||||
# 2 x input_ids for "question: How are you? \n context: I had a long day, "
|
||||
input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]]).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(torch_device)
|
||||
|
||||
result = model.generate(
|
||||
input_ids,
|
||||
max_length=10,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
forced_eos_token_id=model.config.eos_token_id,
|
||||
num_beams=4,
|
||||
do_sample=False,
|
||||
num_return_sequences=3,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
|
||||
transition_scores = model.compute_transition_beam_scores(
|
||||
sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
|
||||
)
|
||||
|
||||
sum_transition_scores = torch.sum(transition_scores, dim=1)
|
||||
|
||||
self.assertListEqual(sum_transition_scores.cpu().tolist(), result.sequences_scores.cpu().tolist())
|
||||
|
||||
def test_log_scores_sample_decoder_only(self):
|
||||
articles = ["I need input_ids to generate", "Short and"]
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
result = model.generate(
|
||||
**inputs,
|
||||
max_length=15,
|
||||
return_dict_in_generate=True,
|
||||
do_sample=False,
|
||||
output_scores=True,
|
||||
)
|
||||
|
||||
# decoder-only starts generating from `input_ids`
|
||||
begin_generation = inputs.input_ids.shape[-1]
|
||||
|
||||
gen_sequences = result.sequences[:, begin_generation:]
|
||||
probs = torch.stack(result.scores, dim=1).softmax(-1)
|
||||
|
||||
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
|
||||
expected_probs = torch.tensor([[0.0014, 0.0015], [0.0014, 0.0014]])
|
||||
|
||||
self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
|
||||
|
||||
def test_log_scores_sample_encoder_decoder(self):
|
||||
articles = ["I need input_ids to generate", "Short and"]
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
|
||||
|
||||
inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
result = model.generate(
|
||||
**inputs,
|
||||
max_length=3,
|
||||
return_dict_in_generate=True,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
output_scores=True,
|
||||
)
|
||||
|
||||
# encoder-decoder has one decoder_start_token_id by default
|
||||
begin_generation = 1
|
||||
|
||||
gen_sequences = result.sequences[:, begin_generation:]
|
||||
probs = torch.stack(result.scores, dim=1).softmax(-1)
|
||||
|
||||
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
|
||||
expected_probs = torch.tensor([[0.0013, 1.0000], [0.0013, 1.0000]])
|
||||
|
||||
self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_beam_search_example_integration(self):
|
||||
# exactly the example provided in the docstrings of beam search, which previously
|
||||
@@ -2366,8 +2454,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
@@ -2403,8 +2491,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_mixed(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||
flexible_phrases = tokenizer(
|
||||
@@ -2442,8 +2530,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_mixed_mixin(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
force_word = "scared"
|
||||
force_flexible = ["scream", "screams", "screaming", "screamed"]
|
||||
|
||||
Reference in New Issue
Block a user