Black 20 release
This commit is contained in:
@@ -83,7 +83,11 @@ class GenerationMixin:
|
||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
self.enforce_repetition_penalty_(
|
||||
scores, batch_size, num_beams, input_ids, repetition_penalty,
|
||||
scores,
|
||||
batch_size,
|
||||
num_beams,
|
||||
input_ids,
|
||||
repetition_penalty,
|
||||
)
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
@@ -324,7 +328,10 @@ class GenerationMixin:
|
||||
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
|
||||
)
|
||||
input_ids = torch.full(
|
||||
(batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
|
||||
(batch_size, 1),
|
||||
bos_token_id,
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
else:
|
||||
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
|
||||
@@ -514,8 +521,8 @@ class GenerationMixin:
|
||||
use_cache,
|
||||
model_specific_kwargs,
|
||||
):
|
||||
""" Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequences are generated independently.
|
||||
"""Generate sequences for each example without beam search (num_beams == 1).
|
||||
All returned sequences are generated independently.
|
||||
"""
|
||||
# length of generated sentences / unfinished sentences
|
||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||
@@ -619,8 +626,7 @@ class GenerationMixin:
|
||||
use_cache,
|
||||
model_specific_kwargs,
|
||||
):
|
||||
""" Generate sequences for each example with beam search.
|
||||
"""
|
||||
"""Generate sequences for each example with beam search."""
|
||||
|
||||
# generated hypotheses
|
||||
generated_hyps = [
|
||||
@@ -749,7 +755,8 @@ class GenerationMixin:
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
continue
|
||||
generated_hyps[batch_idx].add(
|
||||
input_ids[effective_beam_id].clone(), beam_token_score.item(),
|
||||
input_ids[effective_beam_id].clone(),
|
||||
beam_token_score.item(),
|
||||
)
|
||||
else:
|
||||
# add next predicted token since it is not eos_token
|
||||
@@ -806,7 +813,8 @@ class GenerationMixin:
|
||||
assert torch.all(
|
||||
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
|
||||
), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
|
||||
next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
|
||||
next_scores[:, :num_beams][batch_idx],
|
||||
beam_scores.view(batch_size, num_beams)[batch_idx],
|
||||
)
|
||||
|
||||
# need to add best num_beams hypotheses to generated hyps
|
||||
@@ -916,7 +924,7 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
|
||||
|
||||
|
||||
def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
|
||||
""" Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
|
||||
"""Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
|
||||
a list of lists in the format [[batch index, vocabulary position], ...]
|
||||
Args:
|
||||
scores: logits distribution of shape (batch size, vocabulary size)
|
||||
@@ -946,14 +954,14 @@ def top_k_top_p_filtering(
|
||||
filter_value: float = -float("Inf"),
|
||||
min_tokens_to_keep: int = 1,
|
||||
) -> Tensor:
|
||||
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
Args:
|
||||
logits: logits distribution shape (batch size, vocabulary size)
|
||||
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
||||
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
||||
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
||||
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
Args:
|
||||
logits: logits distribution shape (batch size, vocabulary size)
|
||||
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
||||
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
||||
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
||||
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||
"""
|
||||
if top_k > 0:
|
||||
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
|
||||
|
||||
Reference in New Issue
Block a user