|
|
|
|
@@ -813,6 +813,49 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
else:
|
|
|
|
|
lprobs[i, previous_token] /= repetition_penalty
|
|
|
|
|
|
|
|
|
|
def postprocess_next_token_scores(
    self,
    scores,
    input_ids,
    no_repeat_ngram_size,
    bad_words_ids,
    cur_len,
    min_length,
    max_length,
    eos_token_id,
    repetition_penalty,
    batch_size,
    num_beams,
):
    """Apply the generation-time constraints to the next-token scores.

    The steps run in this fixed order:

    1. repetition penalty (skipped when ``repetition_penalty == 1.0``),
    2. EOS masking while ``cur_len < min_length``,
    3. banning of tokens that would repeat an n-gram of size
       ``no_repeat_ngram_size`` (skipped when the size is ``0`` or less),
    4. banning of tokens derived from ``bad_words_ids`` (skipped when ``None``).

    ``scores`` is modified IN PLACE by steps 2-4 and the same object is
    returned, so callers must not rely on the argument staying untouched.

    Args:
        scores: 2-D tensor of next-token scores, indexed as
            ``scores[row, token_id]``.
            NOTE(review): presumably ``(batch_size * num_beams, vocab_size)``
            -- confirm against the callers.
        input_ids: token ids generated so far; forwarded unchanged to the
            repetition-penalty and banned-token helpers.
        no_repeat_ngram_size: n-gram size for the no-repeat constraint.
        bad_words_ids: bad-word specification forwarded to
            ``calc_banned_bad_words_ids``, or ``None`` to disable.
        cur_len: current length of the generated sequences.
        min_length: EOS is forbidden while ``cur_len`` is below this value.
        max_length: accepted for signature parity with the callers; not read
            by this method.
        eos_token_id: id of the end-of-sequence token, or ``None``.
        repetition_penalty: penalty factor; ``1.0`` means no penalty.
        batch_size: number of input sequences.
        num_beams: number of beams per sequence (``1`` when not beam searching).

    Returns:
        The ``scores`` tensor that was passed in, after the masking above.
    """
    # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
    if repetition_penalty != 1.0:
        # The helper's return value is not used, so it presumably updates
        # `scores` in place (trailing-underscore naming) -- TODO confirm.
        self.enforce_repetition_penalty_(
            scores, batch_size, num_beams, input_ids, repetition_penalty,
        )

    # set eos token prob to zero if min_length is not reached
    if eos_token_id is not None and cur_len < min_length:
        scores[:, eos_token_id] = -float("inf")

    if no_repeat_ngram_size > 0:
        # calculate a list of banned tokens to prevent repetitively generating the same ngrams
        num_batch_hypotheses = batch_size * num_beams
        # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
        banned_ngram_tokens = calc_banned_ngram_tokens(
            input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
        )
        for row_idx, row_banned in enumerate(banned_ngram_tokens):
            scores[row_idx, row_banned] = -float("inf")

    if bad_words_ids is not None:
        # calculate a list of banned tokens according to bad words
        banned_bad_word_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
        # Distinct names for the per-row list and the full collection: the
        # loop variable must not shadow the sequence being iterated.
        for row_idx, row_banned in enumerate(banned_bad_word_tokens):
            scores[row_idx, row_banned] = -float("inf")

    return scores
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def generate(
|
|
|
|
|
self,
|
|
|
|
|
@@ -1222,7 +1265,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
|
|
|
|
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
|
|
|
|
|
|
|
|
|
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
|
|
|
|
|
past = (encoder_outputs, None) if encoder_outputs is not None else None
|
|
|
|
|
|
|
|
|
|
while cur_len < max_length:
|
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(
|
|
|
|
|
@@ -1232,40 +1275,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
outputs = self(**model_inputs)
|
|
|
|
|
next_token_logits = outputs[0][:, -1, :]
|
|
|
|
|
|
|
|
|
|
scores = self.postprocess_next_token_scores(
|
|
|
|
|
scores=next_token_logits,
|
|
|
|
|
input_ids=input_ids,
|
|
|
|
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
|
|
|
|
bad_words_ids=bad_words_ids,
|
|
|
|
|
cur_len=cur_len,
|
|
|
|
|
min_length=min_length,
|
|
|
|
|
max_length=max_length,
|
|
|
|
|
eos_token_id=eos_token_id,
|
|
|
|
|
repetition_penalty=repetition_penalty,
|
|
|
|
|
batch_size=batch_size,
|
|
|
|
|
num_beams=1,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# if model has past, then set the past variable to speed up decoding
|
|
|
|
|
if self._use_cache(outputs, use_cache):
|
|
|
|
|
past = outputs[1]
|
|
|
|
|
|
|
|
|
|
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
|
|
|
|
if repetition_penalty != 1.0:
|
|
|
|
|
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
|
|
|
|
|
|
|
|
|
|
if no_repeat_ngram_size > 0:
|
|
|
|
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
|
|
|
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
|
|
|
|
banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
|
|
|
|
for batch_idx in range(batch_size):
|
|
|
|
|
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
|
|
|
|
|
|
|
|
|
|
if bad_words_ids is not None:
|
|
|
|
|
# calculate a list of banned tokens according to bad words
|
|
|
|
|
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
|
|
|
|
|
|
|
|
|
|
for batch_idx in range(batch_size):
|
|
|
|
|
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
|
|
|
|
|
|
|
|
|
|
# set eos token prob to zero if min_length is not reached
|
|
|
|
|
if eos_token_id is not None and cur_len < min_length:
|
|
|
|
|
next_token_logits[:, eos_token_id] = -float("inf")
|
|
|
|
|
|
|
|
|
|
if do_sample:
|
|
|
|
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
|
|
|
|
if temperature != 1.0:
|
|
|
|
|
next_token_logits = next_token_logits / temperature
|
|
|
|
|
scores = scores / temperature
|
|
|
|
|
# Top-p/top-k filtering
|
|
|
|
|
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
|
|
|
|
next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
|
|
|
|
|
# Sample
|
|
|
|
|
probs = F.softmax(next_token_logits, dim=-1)
|
|
|
|
|
probs = F.softmax(next_token_logscores, dim=-1)
|
|
|
|
|
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
|
|
|
|
else:
|
|
|
|
|
# Greedy decoding
|
|
|
|
|
@@ -1300,18 +1335,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# if there are different sentences lengths in the batch, some batches have to be padded
|
|
|
|
|
if sent_lengths.min().item() != sent_lengths.max().item():
|
|
|
|
|
assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
|
|
|
|
|
# finished sents are filled with pad_token
|
|
|
|
|
decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_(pad_token_id)
|
|
|
|
|
else:
|
|
|
|
|
decoded = input_ids
|
|
|
|
|
|
|
|
|
|
for hypo_idx, hypo in enumerate(input_ids):
|
|
|
|
|
decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]
|
|
|
|
|
|
|
|
|
|
return decoded
|
|
|
|
|
return input_ids
|
|
|
|
|
|
|
|
|
|
def _generate_beam_search(
|
|
|
|
|
self,
|
|
|
|
|
@@ -1357,7 +1381,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
|
|
|
|
|
|
|
|
|
# cache compute states
|
|
|
|
|
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
|
|
|
|
|
past = (encoder_outputs, None) if encoder_outputs is not None else None
|
|
|
|
|
|
|
|
|
|
# done sentences
|
|
|
|
|
done = [False for _ in range(batch_size)]
|
|
|
|
|
@@ -1373,43 +1397,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
if self._use_cache(outputs, use_cache):
|
|
|
|
|
past = outputs[1]
|
|
|
|
|
|
|
|
|
|
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
|
|
|
|
if repetition_penalty != 1.0:
|
|
|
|
|
self.enforce_repetition_penalty_(
|
|
|
|
|
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty,
|
|
|
|
|
)
|
|
|
|
|
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
|
|
|
|
|
|
|
|
|
if temperature != 1.0:
|
|
|
|
|
next_token_logits = next_token_logits / temperature
|
|
|
|
|
scores = self.postprocess_next_token_scores(
|
|
|
|
|
scores=scores,
|
|
|
|
|
input_ids=input_ids,
|
|
|
|
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
|
|
|
|
bad_words_ids=bad_words_ids,
|
|
|
|
|
cur_len=cur_len,
|
|
|
|
|
min_length=min_length,
|
|
|
|
|
max_length=max_length,
|
|
|
|
|
eos_token_id=eos_token_id,
|
|
|
|
|
repetition_penalty=repetition_penalty,
|
|
|
|
|
batch_size=batch_size,
|
|
|
|
|
num_beams=num_beams,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if self.config.is_encoder_decoder and do_sample is False:
|
|
|
|
|
# TODO (PVP) still a bit hacky here - there might be a better solution
|
|
|
|
|
next_token_logits = self.prepare_logits_for_generation(
|
|
|
|
|
next_token_logits, cur_len=cur_len, max_length=max_length
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
|
|
|
|
|
|
|
|
|
# set eos token prob to zero if min_length is not reached
|
|
|
|
|
if eos_token_id is not None and cur_len < min_length:
|
|
|
|
|
scores[:, eos_token_id] = -float("inf")
|
|
|
|
|
|
|
|
|
|
if no_repeat_ngram_size > 0:
|
|
|
|
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
|
|
|
|
num_batch_hypotheses = batch_size * num_beams
|
|
|
|
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
|
|
|
|
banned_batch_tokens = calc_banned_ngram_tokens(
|
|
|
|
|
input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
|
|
|
|
|
)
|
|
|
|
|
for i, banned_tokens in enumerate(banned_batch_tokens):
|
|
|
|
|
scores[i, banned_tokens] = -float("inf")
|
|
|
|
|
|
|
|
|
|
if bad_words_ids is not None:
|
|
|
|
|
# calculate a list of banned tokens according to bad words
|
|
|
|
|
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
|
|
|
|
|
|
|
|
|
|
for i, banned_tokens in enumerate(banned_tokens):
|
|
|
|
|
scores[i, banned_tokens] = -float("inf")
|
|
|
|
|
scores = self.prepare_logits_for_generation(scores, cur_len=cur_len, max_length=max_length)
|
|
|
|
|
|
|
|
|
|
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
|
|
|
|
|
scores.shape, (batch_size * num_beams, vocab_size)
|
|
|
|
|
@@ -1417,6 +1423,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
|
|
|
|
|
if do_sample:
|
|
|
|
|
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
|
|
|
|
# Temperature
|
|
|
|
|
if temperature != 1.0:
|
|
|
|
|
_scores = _scores / temperature
|
|
|
|
|
# Top-p/top-k filtering
|
|
|
|
|
_scores = top_k_top_p_filtering(
|
|
|
|
|
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
|
|
|
|
|
|