fix conflicts
This commit is contained in:
committed by
Patrick von Platen
parent
77e6775065
commit
c62444da39
@@ -587,6 +587,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||
return {"input_ids": input_ids}
|
||||
|
||||
def prepare_scores_for_generation(self, scores, **kwargs):
|
||||
return scores
|
||||
|
||||
def _do_output_past(self, outputs):
|
||||
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
||||
has_output_past = getattr(self.config, "output_past", False)
|
||||
@@ -940,20 +943,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
if repetition_penalty != 1.0:
|
||||
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
|
||||
|
||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
if no_repeat_ngram_size > 0:
|
||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len - 1)
|
||||
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
||||
for batch_idx in range(batch_size):
|
||||
next_token_logits[
|
||||
batch_idx, banned_tokens[batch_idx]
|
||||
] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
||||
] = -float('inf')
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
for eos_token_id in eos_token_ids:
|
||||
next_token_logits[
|
||||
:, eos_token_id
|
||||
] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
||||
] = -float('inf')
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
@@ -1037,12 +1041,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
|
||||
# generated hypotheses
|
||||
generated_hyps = [
|
||||
BeamHypotheses(num_beams, max_length - 1, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
|
||||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
|
||||
# BeamHypotheses(num_beams, max_length - 2, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
|
||||
]
|
||||
|
||||
# scores for each sentence in the beam
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
# Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
|
||||
|
||||
# for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
|
||||
if do_sample is False:
|
||||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||
@@ -1068,41 +1074,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty,
|
||||
)
|
||||
|
||||
if cur_len < min_length and eos_token_ids is not None:
|
||||
if temperature != 1.0:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
if self.config.is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here?
|
||||
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
for eos_token_id in eos_token_ids:
|
||||
next_token_logits[:, eos_token_id] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
||||
scores[:, eos_token_id] = -float('inf')
|
||||
|
||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
if no_repeat_ngram_size > 0:
|
||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
num_batch_hypotheses = batch_size * num_beams
|
||||
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len - 1)
|
||||
for batch_idx in range(batch_size):
|
||||
next_token_logits[
|
||||
batch_idx, banned_tokens[batch_idx]
|
||||
] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
||||
banned_batch_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)
|
||||
for i, banned_tokens in enumerate(banned_batch_tokens):
|
||||
scores[i, banned_tokens] = -float('inf')
|
||||
|
||||
# force eos to be chosen at end of generation for encoder-decoder models
|
||||
# TODO (PVP): both these things are very hacky see whether it might be possible to solve this differently
|
||||
if self.config.is_encoder_decoder:
|
||||
if cur_len == 1:
|
||||
self._force_token_ids_generation(next_token_logits, bos_token_id)
|
||||
if cur_len == max_length - 1:
|
||||
self._force_token_ids_generation(next_token_logits, eos_token_ids)
|
||||
# self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length)
|
||||
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(scores.shape, (batch_size * num_beams, vocab_size))
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
if temperature != 1.0:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# Top-p/top-k filtering
|
||||
_scores = top_k_top_p_filtering(
|
||||
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
|
||||
) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# re-organize to group the beam together to sample from all beam_idxs
|
||||
_scores = _scores.contiguous().view(
|
||||
batch_size, num_beams * vocab_size
|
||||
@@ -1112,48 +1111,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
next_tokens = torch.multinomial(
|
||||
F.softmax(_scores, dim=-1), num_samples=2 * num_beams
|
||||
) # (batch_size, num_beams * 2)
|
||||
|
||||
# Compute next scores
|
||||
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
|
||||
|
||||
# sort the sampled vector to make sure that the first num_beams samples are the best
|
||||
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
|
||||
next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
|
||||
|
||||
else:
|
||||
# do greedy beam search
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# if self.config.is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here?
|
||||
# import math
|
||||
# scores[scores != scores] = -math.inf # block nans => seems very hacky here
|
||||
# scores[:, pad_token_id] = -math.inf # => seems very hacky here
|
||||
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
|
||||
# if cur_len == 1: # Force BOS to be chosen => also very hacky ... seems also to work without this line
|
||||
# scores[:, self.config.bos_token_id + 1 :] = -math.inf
|
||||
# if cur_len == max_length - 1: # FORCE EOS to be chosen
|
||||
# all_but_eos_mask = torch.tensor(
|
||||
# [x for x in range(vocab_size) if x not in eos_token_ids],
|
||||
# dtype=torch.long,
|
||||
# device=next(self.parameters()).device,
|
||||
# )
|
||||
# scores[:, all_but_eos_mask] = -math.inf
|
||||
|
||||
# if eos_token_ids is not None and cur_len < min_length:
|
||||
# for eos_token_id in eos_token_ids:
|
||||
# scores[:, eos_token_id] = -math.inf # set eos token prob to 0 as is done for attention masks
|
||||
#
|
||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
# if no_repeat_ngram_size > 0:
|
||||
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||
# banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len - 1)
|
||||
# for batch_idx in range(batch_size):
|
||||
# scores[
|
||||
# batch_idx, banned_tokens[batch_idx]
|
||||
# ] = -math.inf # set eos token prob to 0 as is done for attention masks
|
||||
|
||||
assert scores.size() == (batch_size * num_beams, vocab_size)
|
||||
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
|
||||
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
||||
next_scores = next_scores.view(
|
||||
batch_size, num_beams * vocab_size
|
||||
@@ -1164,16 +1130,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
|
||||
|
||||
# next batch beam content
|
||||
# list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
|
||||
next_batch_beam = []
|
||||
|
||||
# for each sentence
|
||||
for batch_idx in range(batch_size):
|
||||
|
||||
# if we are done with this sentence
|
||||
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
||||
next_scores[batch_idx].max().item(), cur_len=cur_len
|
||||
)
|
||||
if done[batch_idx]:
|
||||
assert (
|
||||
len(generated_hyps[batch_idx]) >= num_beams
|
||||
@@ -1188,15 +1150,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
next_sent_beam = []
|
||||
|
||||
# next tokens for this sentence
|
||||
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
|
||||
|
||||
for i, (idx, score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):
|
||||
# get beam and word IDs
|
||||
beam_id = idx // vocab_size
|
||||
token_id = idx % vocab_size
|
||||
|
||||
effective_beam_id = batch_idx * num_beams + beam_id
|
||||
|
||||
# add to generated hypotheses if end of sentence
|
||||
if eos_token_ids is not None and token_id.item() in eos_token_ids:
|
||||
if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
|
||||
# when passed to num_beams hypotheses, continue
|
||||
if i >= num_beams:
|
||||
continue
|
||||
generated_hyps[batch_idx].add(
|
||||
input_ids[effective_beam_id].clone(), score.item(),
|
||||
)
|
||||
@@ -1208,11 +1173,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
if len(next_sent_beam) == num_beams:
|
||||
break
|
||||
|
||||
# Check if were done so that we can save a pad step if all(done)
|
||||
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
||||
next_scores[batch_idx].max().item(), cur_len=cur_len
|
||||
)
|
||||
|
||||
# update next beam content
|
||||
assert len(next_sent_beam) == num_beams, "Beam should always be full"
|
||||
next_batch_beam.extend(next_sent_beam)
|
||||
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
|
||||
|
||||
# stop when we are done with each sentence
|
||||
if all(done):
|
||||
break
|
||||
|
||||
# sanity check / prepare next batch
|
||||
assert len(next_batch_beam) == batch_size * num_beams
|
||||
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
||||
@@ -1227,10 +1201,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
if past:
|
||||
past = self._reorder_cache(past, beam_idx)
|
||||
|
||||
# stop when we are done with each sentence
|
||||
if all(done):
|
||||
break
|
||||
|
||||
# extend attention_mask for new generated input
|
||||
if self.config.is_encoder_decoder is False:
|
||||
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1)
|
||||
@@ -1299,7 +1269,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
return decoded
|
||||
|
||||
# force one of token_ids to be generated by setting prob of all other tokens to 0.
|
||||
def _force_token_ids_generation(self, logits, token_ids):
|
||||
def _force_token_ids_generation(self, scores, token_ids):
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
all_but_token_ids_mask = torch.tensor(
|
||||
@@ -1307,9 +1277,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
assert len(logits.shape) == 2, "logits should be of rank 2 with shape: [batch_size, vocab_size]"
|
||||
logits[:, all_but_token_ids_mask] = -10000.0
|
||||
return logits
|
||||
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
|
||||
scores[:, all_but_token_ids_mask] = -float('inf')
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
@@ -1326,9 +1295,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
return past
|
||||
|
||||
|
||||
def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, step):
|
||||
def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
|
||||
# Copied from fairseq for no_repeat_ngram in beam_search"""
|
||||
if step + 2 < no_repeat_ngram_size:
|
||||
if cur_len + 1 < no_repeat_ngram_size:
|
||||
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
||||
return [[] for _ in range(num_hypos)]
|
||||
generated_ngrams = [{} for _ in range(num_hypos)]
|
||||
@@ -1341,9 +1310,8 @@ def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, step):
|
||||
|
||||
def _get_generated_ngrams(hypo_idx):
|
||||
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
||||
start_idx = step + 2 - no_repeat_ngram_size
|
||||
end_idx = step + 1
|
||||
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:end_idx].tolist())
|
||||
start_idx = cur_len + 1 - no_repeat_ngram_size
|
||||
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx: cur_len].tolist())
|
||||
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
||||
|
||||
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
||||
|
||||
Reference in New Issue
Block a user