[Generation] Fix Transition probs (#17311)
* [Draft] fix transition probs * up * up * up * make it work * fix * finish * update
This commit is contained in:
committed by
GitHub
parent
e8714c0307
commit
518bd02c9b
@@ -212,6 +212,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
next_indices: torch.LongTensor,
|
next_indices: torch.LongTensor,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[int] = None,
|
||||||
|
beam_indices: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.Tensor]:
|
) -> Tuple[torch.Tensor]:
|
||||||
cur_len = input_ids.shape[-1]
|
cur_len = input_ids.shape[-1]
|
||||||
batch_size = len(self._beam_hyps)
|
batch_size = len(self._beam_hyps)
|
||||||
@@ -256,9 +257,16 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
||||||
if is_beam_token_worse_than_top_num_beams:
|
if is_beam_token_worse_than_top_num_beams:
|
||||||
continue
|
continue
|
||||||
|
if beam_indices is not None:
|
||||||
|
beam_index = beam_indices[batch_beam_idx]
|
||||||
|
beam_index = beam_index + (next_index,)
|
||||||
|
else:
|
||||||
|
beam_index = None
|
||||||
|
|
||||||
beam_hyp.add(
|
beam_hyp.add(
|
||||||
input_ids[batch_beam_idx].clone(),
|
input_ids[batch_beam_idx].clone(),
|
||||||
next_score.item(),
|
next_score.item(),
|
||||||
|
beam_indices=beam_index,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# add next predicted token since it is not eos_token
|
# add next predicted token since it is not eos_token
|
||||||
@@ -299,6 +307,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
max_length: int,
|
max_length: int,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[int] = None,
|
||||||
|
beam_indices: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.LongTensor]:
|
) -> Tuple[torch.LongTensor]:
|
||||||
batch_size = len(self._beam_hyps)
|
batch_size = len(self._beam_hyps)
|
||||||
|
|
||||||
@@ -313,11 +322,13 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
batch_beam_idx = batch_idx * self.num_beams + beam_id
|
||||||
final_score = final_beam_scores[batch_beam_idx].item()
|
final_score = final_beam_scores[batch_beam_idx].item()
|
||||||
final_tokens = input_ids[batch_beam_idx]
|
final_tokens = input_ids[batch_beam_idx]
|
||||||
beam_hyp.add(final_tokens, final_score)
|
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
||||||
|
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)
|
||||||
|
|
||||||
# select the best hypotheses
|
# select the best hypotheses
|
||||||
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
||||||
best = []
|
best = []
|
||||||
|
best_indices = []
|
||||||
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
||||||
|
|
||||||
# retrieve best hypotheses
|
# retrieve best hypotheses
|
||||||
@@ -327,23 +338,42 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
best_hyp_tuple = sorted_hyps.pop()
|
best_hyp_tuple = sorted_hyps.pop()
|
||||||
best_score = best_hyp_tuple[0]
|
best_score = best_hyp_tuple[0]
|
||||||
best_hyp = best_hyp_tuple[1]
|
best_hyp = best_hyp_tuple[1]
|
||||||
|
best_index = best_hyp_tuple[2]
|
||||||
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
||||||
|
|
||||||
# append to lists
|
# append hyp to lists
|
||||||
best.append(best_hyp)
|
best.append(best_hyp)
|
||||||
|
|
||||||
|
# append indices to list
|
||||||
|
best_indices.append(best_index)
|
||||||
|
|
||||||
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
||||||
|
|
||||||
# prepare for adding eos
|
# prepare for adding eos
|
||||||
sent_lengths_max = sent_lengths.max().item() + 1
|
sent_lengths_max = sent_lengths.max().item() + 1
|
||||||
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
||||||
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||||
|
|
||||||
|
if len(best_indices) > 0 and best_indices[0] is not None:
|
||||||
|
indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||||
|
else:
|
||||||
|
indices = None
|
||||||
|
|
||||||
# shorter batches are padded if needed
|
# shorter batches are padded if needed
|
||||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||||
assert pad_token_id is not None, "`pad_token_id` has to be defined"
|
assert pad_token_id is not None, "`pad_token_id` has to be defined"
|
||||||
decoded.fill_(pad_token_id)
|
decoded.fill_(pad_token_id)
|
||||||
|
|
||||||
|
if indices is not None:
|
||||||
|
indices.fill_(-1)
|
||||||
|
|
||||||
# fill with hypotheses and eos_token_id if the latter fits in
|
# fill with hypotheses and eos_token_id if the latter fits in
|
||||||
for i, hypo in enumerate(best):
|
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
|
||||||
decoded[i, : sent_lengths[i]] = hypo
|
decoded[i, : sent_lengths[i]] = hypo
|
||||||
|
|
||||||
|
if indices is not None:
|
||||||
|
indices[i, : len(best_idx)] = torch.tensor(best_idx)
|
||||||
|
|
||||||
if sent_lengths[i] < sent_max_len:
|
if sent_lengths[i] < sent_max_len:
|
||||||
decoded[i, sent_lengths[i]] = eos_token_id
|
decoded[i, sent_lengths[i]] = eos_token_id
|
||||||
|
|
||||||
@@ -351,6 +381,7 @@ class BeamSearchScorer(BeamScorer):
|
|||||||
{
|
{
|
||||||
"sequences": decoded,
|
"sequences": decoded,
|
||||||
"sequence_scores": best_scores,
|
"sequence_scores": best_scores,
|
||||||
|
"beam_indices": indices,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -789,6 +820,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
|
|
||||||
# prepare for adding eos
|
# prepare for adding eos
|
||||||
sent_lengths_max = sent_lengths.max().item() + 1
|
sent_lengths_max = sent_lengths.max().item() + 1
|
||||||
|
|
||||||
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
||||||
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
||||||
# shorter batches are padded if needed
|
# shorter batches are padded if needed
|
||||||
@@ -801,6 +833,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
|
|||||||
decoded[i, : sent_lengths[i]] = hypo
|
decoded[i, : sent_lengths[i]] = hypo
|
||||||
if sent_lengths[i] < sent_max_len:
|
if sent_lengths[i] < sent_max_len:
|
||||||
decoded[i, sent_lengths[i]] = eos_token_id
|
decoded[i, sent_lengths[i]] = eos_token_id
|
||||||
|
|
||||||
return UserDict(
|
return UserDict(
|
||||||
{
|
{
|
||||||
"sequences": decoded,
|
"sequences": decoded,
|
||||||
@@ -826,15 +859,15 @@ class BeamHypotheses:
|
|||||||
"""
|
"""
|
||||||
return len(self.beams)
|
return len(self.beams)
|
||||||
|
|
||||||
def add(self, hyp: torch.LongTensor, sum_logprobs: float):
|
def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
|
||||||
"""
|
"""
|
||||||
Add a new hypothesis to the list.
|
Add a new hypothesis to the list.
|
||||||
"""
|
"""
|
||||||
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
||||||
if len(self) < self.num_beams or score > self.worst_score:
|
if len(self) < self.num_beams or score > self.worst_score:
|
||||||
self.beams.append((score, hyp))
|
self.beams.append((score, hyp, beam_indices))
|
||||||
if len(self) > self.num_beams:
|
if len(self) > self.num_beams:
|
||||||
sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
|
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
|
||||||
del self.beams[sorted_next_scores[0][1]]
|
del self.beams[sorted_next_scores[0][1]]
|
||||||
self.worst_score = sorted_next_scores[1][0]
|
self.worst_score = sorted_next_scores[1][0]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -217,8 +217,8 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
|
|||||||
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
||||||
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
||||||
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
||||||
tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
`(batch_size*num_return_sequences, input_ids.shape[-1])`.
|
||||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
||||||
@@ -230,7 +230,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
|
|||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
sequences_scores: Optional[torch.FloatTensor] = None
|
sequences_scores: Optional[torch.FloatTensor] = None
|
||||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
beam_indices: Optional[torch.LongTensor] = None
|
||||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
|
||||||
@@ -254,8 +254,8 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
|||||||
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
||||||
config.vocab_size)`).
|
config.vocab_size)`).
|
||||||
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
||||||
tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
`(batch_size*num_return_sequences, max_length-1)`.
|
||||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
||||||
@@ -278,7 +278,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
|
|||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
sequences_scores: Optional[torch.FloatTensor] = None
|
sequences_scores: Optional[torch.FloatTensor] = None
|
||||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
beam_indices: Optional[torch.LongTensor] = None
|
||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
@@ -303,8 +303,8 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
|
|||||||
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
|
||||||
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
|
||||||
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
||||||
tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
`(batch_size*num_return_sequences, input_ids.shape[-1])`.
|
||||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
||||||
@@ -316,7 +316,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
|
|||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
sequences_scores: Optional[torch.FloatTensor] = None
|
sequences_scores: Optional[torch.FloatTensor] = None
|
||||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
beam_indices: Optional[torch.LongTensor] = None
|
||||||
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
|
||||||
@@ -339,9 +339,9 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
|||||||
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||||
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
|
||||||
config.vocab_size)`).
|
config.vocab_size)`).
|
||||||
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
||||||
Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
|
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
||||||
tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
|
`(batch_size*num_return_sequences, max_length-1)`.
|
||||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length, sequence_length)`.
|
||||||
@@ -362,7 +362,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
|
|||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
sequences_scores: Optional[torch.FloatTensor] = None
|
sequences_scores: Optional[torch.FloatTensor] = None
|
||||||
scores: Optional[Tuple[torch.FloatTensor]] = None
|
scores: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
|
beam_indices: Optional[torch.LongTensor] = None
|
||||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
@@ -811,32 +811,33 @@ class GenerationMixin:
|
|||||||
"""compute the transition probabilities of sequences given generation
|
"""compute the transition probabilities of sequences given generation
|
||||||
scores and beam indices"""
|
scores and beam indices"""
|
||||||
|
|
||||||
# reshape scores as [vocab_size * batch_size, # generation steps]
|
# 1. reshape scores as [vocab_size * batch_size, # generation steps]
|
||||||
# with batch_size being 2 * vocab_size and # generation steps being
|
# with batch_size being 2 * vocab_size and # generation steps being
|
||||||
# seq_len - input_length
|
# seq_len - input_length
|
||||||
scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)
|
scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)
|
||||||
|
|
||||||
# start of generated tokens
|
# 2. cut beam_indices to longest beam length
|
||||||
cut_idx = sequences.shape[-1] - scores.shape[-1]
|
beam_indices_mask = beam_indices < 0
|
||||||
# adjust for beam indices
|
max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
|
||||||
beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size
|
beam_indices = beam_indices[:, :max_beam_length]
|
||||||
# compute real indices
|
beam_indices_mask = beam_indices_mask[:, :max_beam_length]
|
||||||
indices = sequences[:, cut_idx:] + beam_sequence_indices
|
|
||||||
# gather scores and run
|
|
||||||
transition_scores = scores.gather(0, indices)
|
|
||||||
# make sure that if EOS token was used before length of sequence `sequence.shape[-1]`
|
|
||||||
# get first occurrence of EOS token
|
|
||||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
|
||||||
|
|
||||||
if eos_token_id is not None:
|
# 3. Set indices of beams that finished early to 0
|
||||||
is_eos_token_id = sequences[:, cut_idx:] == eos_token_id
|
# such indices will be masked correctly afterwards
|
||||||
# make sure first eos token still contributes to transition probs
|
beam_indices[beam_indices_mask] = 0
|
||||||
is_eos_token_id[:, -1] = False
|
|
||||||
is_eos_token_id = is_eos_token_id.roll(1, -1)
|
# 4. multiply beam_indices with vocab size to gather correctly from scores
|
||||||
# all indices after eos should be masked
|
beam_sequence_indices = beam_indices * self.config.vocab_size
|
||||||
zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool()
|
|
||||||
# zero out padded probs
|
# 5. Define which indices contributed to scores
|
||||||
transition_scores.masked_fill_(zero_transition_prob_mask, 0.0)
|
cut_idx = sequences.shape[-1] - max_beam_length
|
||||||
|
indices = sequences[:, cut_idx:] + beam_sequence_indices
|
||||||
|
|
||||||
|
# 6. Compute scores
|
||||||
|
transition_scores = scores.gather(0, indices)
|
||||||
|
|
||||||
|
# 7. Mask out transition_scores of beams that stopped early
|
||||||
|
transition_scores[beam_indices_mask] = 0
|
||||||
|
|
||||||
return transition_scores
|
return transition_scores
|
||||||
|
|
||||||
@@ -2256,6 +2257,7 @@ class GenerationMixin:
|
|||||||
next_indices,
|
next_indices,
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
|
beam_indices=beam_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
beam_scores = beam_outputs["next_beam_scores"]
|
beam_scores = beam_outputs["next_beam_scores"]
|
||||||
@@ -2290,25 +2292,19 @@ class GenerationMixin:
|
|||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
max_length=stopping_criteria.max_length,
|
max_length=stopping_criteria.max_length,
|
||||||
|
beam_indices=beam_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if not output_scores:
|
if not output_scores:
|
||||||
sequence_outputs["sequence_scores"] = None
|
sequence_outputs["sequence_scores"] = None
|
||||||
else:
|
|
||||||
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
|
|
||||||
# return only as many indices as sequences
|
|
||||||
beam_indices = tuple(
|
|
||||||
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
|
|
||||||
)
|
|
||||||
beam_indices = sum(beam_indices, ())
|
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return BeamSearchEncoderDecoderOutput(
|
return BeamSearchEncoderDecoderOutput(
|
||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
beam_indices=beam_indices,
|
beam_indices=sequence_outputs["beam_indices"],
|
||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
@@ -2320,7 +2316,7 @@ class GenerationMixin:
|
|||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
beam_indices=beam_indices,
|
beam_indices=sequence_outputs["beam_indices"],
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
)
|
)
|
||||||
@@ -2580,6 +2576,7 @@ class GenerationMixin:
|
|||||||
next_indices,
|
next_indices,
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
|
beam_indices=beam_indices,
|
||||||
)
|
)
|
||||||
beam_scores = beam_outputs["next_beam_scores"]
|
beam_scores = beam_outputs["next_beam_scores"]
|
||||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||||
@@ -2613,25 +2610,19 @@ class GenerationMixin:
|
|||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
max_length=stopping_criteria.max_length,
|
max_length=stopping_criteria.max_length,
|
||||||
|
beam_indices=beam_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if not output_scores:
|
if not output_scores:
|
||||||
sequence_outputs["sequence_scores"] = None
|
sequence_outputs["sequence_scores"] = None
|
||||||
else:
|
|
||||||
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
|
|
||||||
# return only as many indices as sequences
|
|
||||||
beam_indices = tuple(
|
|
||||||
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
|
|
||||||
)
|
|
||||||
beam_indices = sum(beam_indices, ())
|
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return BeamSampleEncoderDecoderOutput(
|
return BeamSampleEncoderDecoderOutput(
|
||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
beam_indices=beam_indices,
|
beam_indices=sequence_outputs["beam_indices"],
|
||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
@@ -2643,7 +2634,7 @@ class GenerationMixin:
|
|||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
beam_indices=beam_indices,
|
beam_indices=sequence_outputs["beam_indices"],
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
)
|
)
|
||||||
@@ -2909,6 +2900,7 @@ class GenerationMixin:
|
|||||||
next_tokens = next_tokens % vocab_size
|
next_tokens = next_tokens % vocab_size
|
||||||
|
|
||||||
# stateless
|
# stateless
|
||||||
|
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
||||||
beam_outputs = beam_scorer.process(
|
beam_outputs = beam_scorer.process(
|
||||||
group_input_ids,
|
group_input_ids,
|
||||||
next_token_scores,
|
next_token_scores,
|
||||||
@@ -2916,6 +2908,7 @@ class GenerationMixin:
|
|||||||
next_indices,
|
next_indices,
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
|
beam_indices=process_beam_indices,
|
||||||
)
|
)
|
||||||
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
||||||
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
||||||
@@ -2971,6 +2964,7 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
this_peer_finished = True
|
this_peer_finished = True
|
||||||
|
|
||||||
|
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
||||||
sequence_outputs = beam_scorer.finalize(
|
sequence_outputs = beam_scorer.finalize(
|
||||||
input_ids,
|
input_ids,
|
||||||
beam_scores,
|
beam_scores,
|
||||||
@@ -2979,26 +2973,19 @@ class GenerationMixin:
|
|||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
max_length=stopping_criteria.max_length,
|
max_length=stopping_criteria.max_length,
|
||||||
|
beam_indices=final_beam_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if not output_scores:
|
if not output_scores:
|
||||||
sequence_outputs["sequence_scores"] = None
|
sequence_outputs["sequence_scores"] = None
|
||||||
else:
|
|
||||||
beam_indices = sum(beam_indices, ())
|
|
||||||
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
|
|
||||||
# return only as many indices as sequences
|
|
||||||
beam_indices = tuple(
|
|
||||||
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
|
|
||||||
)
|
|
||||||
beam_indices = sum(beam_indices, ())
|
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return BeamSearchEncoderDecoderOutput(
|
return BeamSearchEncoderDecoderOutput(
|
||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
beam_indices=beam_indices,
|
beam_indices=sequence_outputs["beam_indices"],
|
||||||
encoder_attentions=encoder_attentions,
|
encoder_attentions=encoder_attentions,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
decoder_attentions=decoder_attentions,
|
decoder_attentions=decoder_attentions,
|
||||||
@@ -3010,6 +2997,7 @@ class GenerationMixin:
|
|||||||
sequences=sequence_outputs["sequences"],
|
sequences=sequence_outputs["sequences"],
|
||||||
sequences_scores=sequence_outputs["sequence_scores"],
|
sequences_scores=sequence_outputs["sequence_scores"],
|
||||||
scores=scores,
|
scores=scores,
|
||||||
|
beam_indices=sequence_outputs["beam_indices"],
|
||||||
attentions=decoder_attentions,
|
attentions=decoder_attentions,
|
||||||
hidden_states=decoder_hidden_states,
|
hidden_states=decoder_hidden_states,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -126,7 +126,11 @@ class BeamSearchTester:
|
|||||||
|
|
||||||
tokens = next_tokens.clone()
|
tokens = next_tokens.clone()
|
||||||
tokens[:, : self.num_beams] = self.eos_token_id
|
tokens[:, : self.num_beams] = self.eos_token_id
|
||||||
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
|
beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
|
||||||
|
beam_indices = tuple(tuple(b) for b in beam_indices)
|
||||||
|
beam_scorer.process(
|
||||||
|
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
|
||||||
|
)
|
||||||
# beam scorer should be done
|
# beam scorer should be done
|
||||||
self.parent.assertTrue(beam_scorer.is_done)
|
self.parent.assertTrue(beam_scorer.is_done)
|
||||||
|
|
||||||
@@ -136,7 +140,7 @@ class BeamSearchTester:
|
|||||||
tokens = next_tokens.clone()
|
tokens = next_tokens.clone()
|
||||||
tokens[:, 1] = self.eos_token_id
|
tokens[:, 1] = self.eos_token_id
|
||||||
beam_outputs = beam_scorer.process(
|
beam_outputs = beam_scorer.process(
|
||||||
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
|
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
|
||||||
)
|
)
|
||||||
output_scores = beam_outputs["next_beam_scores"]
|
output_scores = beam_outputs["next_beam_scores"]
|
||||||
output_tokens = beam_outputs["next_beam_tokens"]
|
output_tokens = beam_outputs["next_beam_tokens"]
|
||||||
@@ -161,10 +165,15 @@ class BeamSearchTester:
|
|||||||
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
|
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
|
||||||
|
|
||||||
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
|
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
|
||||||
|
expected_beam_indices = list(range(10))
|
||||||
for batch_idx in range(self.batch_size):
|
for batch_idx in range(self.batch_size):
|
||||||
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
|
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
|
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
|
||||||
|
)
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
expected_beam_indices + [next_indices[batch_idx, 1].item()],
|
||||||
|
torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
|
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
|
||||||
@@ -188,6 +197,8 @@ class BeamSearchTester:
|
|||||||
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
|
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
|
||||||
|
|
||||||
# finalize
|
# finalize
|
||||||
|
beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
|
||||||
|
beam_indices = tuple(tuple(b) for b in beam_indices)
|
||||||
sequence_output = beam_scorer.finalize(
|
sequence_output = beam_scorer.finalize(
|
||||||
input_ids,
|
input_ids,
|
||||||
output_scores,
|
output_scores,
|
||||||
@@ -196,6 +207,7 @@ class BeamSearchTester:
|
|||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
eos_token_id=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
|
beam_indices=beam_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
sequences = sequence_output["sequences"]
|
sequences = sequence_output["sequences"]
|
||||||
@@ -225,6 +237,7 @@ class BeamSearchTester:
|
|||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
eos_token_id=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
|
beam_indices=beam_indices,
|
||||||
)
|
)
|
||||||
sequences = sequence_output["sequences"]
|
sequences = sequence_output["sequences"]
|
||||||
sequence_scores = sequence_output["sequence_scores"]
|
sequence_scores = sequence_output["sequence_scores"]
|
||||||
@@ -394,7 +407,7 @@ class ConstrainedBeamSearchTester:
|
|||||||
for batch_idx in range(self.batch_size):
|
for batch_idx in range(self.batch_size):
|
||||||
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
|
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
|
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_constrained_beam_scorer_finalize(
|
def check_constrained_beam_scorer_finalize(
|
||||||
|
|||||||
@@ -2322,6 +2322,94 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_transition_scores_early_stopping(self):
|
||||||
|
# This is an aggressive test that makes sure that `beam_search`'s
|
||||||
|
# transition scores are computed correctly for varying `num_return_sequences`,
|
||||||
|
# `num_beams` and `batch_size > 1`
|
||||||
|
# 2 x input_ids for "question: How are you? \n context: I had a long day, "
|
||||||
|
input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]]).to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
|
||||||
|
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(torch_device)
|
||||||
|
|
||||||
|
result = model.generate(
|
||||||
|
input_ids,
|
||||||
|
max_length=10,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_scores=True,
|
||||||
|
forced_eos_token_id=model.config.eos_token_id,
|
||||||
|
num_beams=4,
|
||||||
|
do_sample=False,
|
||||||
|
num_return_sequences=3,
|
||||||
|
length_penalty=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
transition_scores = model.compute_transition_beam_scores(
|
||||||
|
sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
|
||||||
|
)
|
||||||
|
|
||||||
|
sum_transition_scores = torch.sum(transition_scores, dim=1)
|
||||||
|
|
||||||
|
self.assertListEqual(sum_transition_scores.cpu().tolist(), result.sequences_scores.cpu().tolist())
|
||||||
|
|
||||||
|
def test_log_scores_sample_decoder_only(self):
|
||||||
|
articles = ["I need input_ids to generate", "Short and"]
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||||
|
|
||||||
|
inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
|
||||||
|
|
||||||
|
result = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_length=15,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
do_sample=False,
|
||||||
|
output_scores=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder-only starts generating from `input_ids`
|
||||||
|
begin_generation = inputs.input_ids.shape[-1]
|
||||||
|
|
||||||
|
gen_sequences = result.sequences[:, begin_generation:]
|
||||||
|
probs = torch.stack(result.scores, dim=1).softmax(-1)
|
||||||
|
|
||||||
|
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
|
||||||
|
expected_probs = torch.tensor([[0.0014, 0.0015], [0.0014, 0.0014]])
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
|
||||||
|
|
||||||
|
def test_log_scores_sample_encoder_decoder(self):
|
||||||
|
articles = ["I need input_ids to generate", "Short and"]
|
||||||
|
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
|
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
|
||||||
|
|
||||||
|
inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
|
||||||
|
|
||||||
|
result = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_length=3,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
do_sample=False,
|
||||||
|
num_beams=1,
|
||||||
|
output_scores=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# encoder-decoder has one decoder_start_token_id by default
|
||||||
|
begin_generation = 1
|
||||||
|
|
||||||
|
gen_sequences = result.sequences[:, begin_generation:]
|
||||||
|
probs = torch.stack(result.scores, dim=1).softmax(-1)
|
||||||
|
|
||||||
|
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
|
||||||
|
expected_probs = torch.tensor([[0.0013, 1.0000], [0.0013, 1.0000]])
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_beam_search_example_integration(self):
|
def test_beam_search_example_integration(self):
|
||||||
# exactly the example provided in the docstrings of beam search, which previously
|
# exactly the example provided in the docstrings of beam search, which previously
|
||||||
@@ -2366,8 +2454,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_constrained_beam_search(self):
|
def test_constrained_beam_search(self):
|
||||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||||
|
|
||||||
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||||
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
|
force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||||
@@ -2403,8 +2491,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_constrained_beam_search_mixed(self):
|
def test_constrained_beam_search_mixed(self):
|
||||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||||
|
|
||||||
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
|
||||||
flexible_phrases = tokenizer(
|
flexible_phrases = tokenizer(
|
||||||
@@ -2442,8 +2530,8 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_constrained_beam_search_mixed_mixin(self):
|
def test_constrained_beam_search_mixed_mixin(self):
|
||||||
model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
|
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||||
|
|
||||||
force_word = "scared"
|
force_word = "scared"
|
||||||
force_flexible = ["scream", "screams", "screaming", "screamed"]
|
force_flexible = ["scream", "screams", "screaming", "screamed"]
|
||||||
|
|||||||
Reference in New Issue
Block a user