fix conflicts

This commit is contained in:
patrickvonplaten
2020-03-08 14:26:08 +01:00
committed by Patrick von Platen
parent 77e6775065
commit c62444da39
3 changed files with 69 additions and 437 deletions

View File

@@ -14,7 +14,6 @@
# limitations under the License.
"""PyTorch BART model, ported from the fairseq repo."""
import logging
import math
import random
from typing import Dict, List, Optional, Tuple
@@ -24,7 +23,7 @@ from torch import Tensor, nn
from .configuration_bart import BartConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import BeamHypotheses, PreTrainedModel, create_position_ids_from_input_ids
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
logger = logging.getLogger(__name__)
@@ -942,22 +941,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
return outputs
@staticmethod
def prepare_inputs_for_generation_bart(input_ids, past, decoder_input_ids, attention_mask):
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
else:
encoder_outputs, decoder_cached_states = past
return {
"input_ids": input_ids, # ignored after first pass
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
}
@staticmethod
def prepare_inputs_for_generation(decoder_input_ids, past, encoder_inputs, attention_mask):
def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask):
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(attention_mask.shape, encoder_inputs.shape)
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
@@ -973,6 +957,13 @@ class BartForConditionalGeneration(PretrainedBartModel):
"attention_mask": attention_mask
}
def prepare_scores_for_generation(self, scores, cur_len, max_length):
if cur_len == 1:
self._force_token_ids_generation(scores, self.config.bos_token_id)
if cur_len == max_length - 1:
self._force_token_ids_generation(scores, self.config.eos_token_ids)
return scores
@staticmethod
def _reorder_cache(past, beam_idx):
((enc_out, enc_mask), decoder_cached_states) = past
@@ -994,273 +985,6 @@ class BartForConditionalGeneration(PretrainedBartModel):
def get_output_embeddings(self):
return self.lm_head
@torch.no_grad()
def generate_bart(
self,
input_ids,
attention_mask=None,
max_length=20,
num_beams=1,
repetition_penalty=1.0,
length_penalty=1.0,
num_return_sequences=1,
min_len=0,
no_repeat_ngram_size=0,
):
r""" Generates summaries using the lm-head and greedy beam search
Adapted in part from Facebook's `XLM beam search code`_ and `Fairseq beam search code`_.
.. _`XLM beam search code`:
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
.. _`Fairseq beam search code`:
https://github.com/pytorch/fairseq/blob/master/fairseq/sequence_generator.py
Parameters:
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape `(1,)`.
max_length: (`optional`) int
The max length of the sequence to be generated. Does not include tokens in input_ids.
num_beams: (`optional`) int
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
repetition_penalty: (`optional`) float
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
length_penalty: (`optional`) float
Exponential penalty to the length. Default to 1.
num_return_sequences: (`optional`) int
The number of independently computed returned sequences for each element in the batch. Default to 1.
min_len: (`optional`) int
Returns:
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
sequence_length is <= max_length (examples can finish early)
Examples::
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
# see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
# Generate Summary
summary_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
"""
bos_token_id = self.config.bos_token_id
pad_token_id = self.config.pad_token_id
eos_token_id = self.config.eos_token_id
batch_size, cur_len = input_ids.shape
assert input_ids is not None
assert self.config.output_past, "Generating with bart requires instantiating a config with output_past=True"
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert isinstance(pad_token_id, int)
assert bos_token_id == 0, "configurable bos_token_id not yet supported"
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(num_return_sequences, int) and num_return_sequences > 0
), "`num_return_sequences` should be a positive integer."
# current position and vocab size
cur_len = input_ids.shape[1]
vocab_size = self.config.vocab_size
if num_return_sequences != 1:
# Expand input to num return sequences
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_return_sequences, cur_len)
input_ids = input_ids.contiguous().view(
batch_size * num_return_sequences, cur_len
) # shape: (batch_size * num_return_sequences, cur_len)
batch_size *= num_return_sequences
# Below here somewhat similar to PretrainedModel._generate_beam_search
# Expand input to num beams
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
if attention_mask is not None:
attention_mask = (
attention_mask.unsqueeze(1)
.expand(batch_size, num_beams, cur_len)
.contiguous()
.view(batch_size * num_beams, cur_len)
) # RESHAPE
# generated hypotheses
finalized_hyps = [ # they end in EOS and we wont work on them more!
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=True) for _ in range(batch_size)
]
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9 # avoid ties in first step
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# decoder tokens
prev_output_tokens = input_ids.new(batch_size * num_beams, 1).long().fill_(-1)
prev_output_tokens[:, 0] = 2 # HARDCODED EOS, which will be removed at the end.
decoder_cache = None
done = [False for _ in range(batch_size)] # done sentences
self.model.decoder.generation_mode = True # tells decoder not to use causal mask
for step in range(max_length + 1):
decoder_input_ids = prev_output_tokens.clone()
model_inputs = self.prepare_inputs_for_generation_bart(
input_ids, decoder_cache, decoder_input_ids, attention_mask,
)
outputs = self(**model_inputs)
lprobs = F.log_softmax(outputs[0][:, -1, :], dim=-1)
lprobs[lprobs != lprobs] = -math.inf # block nans
lprobs[:, pad_token_id] = -math.inf
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
if step == 0: # Force BOS to be chosen
lprobs[:, bos_token_id + 1 :] = -math.inf
elif step < min_len: # Prevent EOS from being chosen
lprobs[:, eos_token_id] = -math.inf
elif step == max_length: # FORCE EOS to be chosen
lprobs[:, :eos_token_id] = -math.inf
lprobs[:, eos_token_id + 1 :] = -math.inf
assert self._do_output_past(outputs)
decoder_cache = outputs[1]
if repetition_penalty != 1.0:
self.enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty)
num_hypos = batch_size * num_beams
if no_repeat_ngram_size > 0: # copied from fairseq
# for each sentence, calculate a list of banned tokens to prevent repetitively generating the same ngrams
banned_tokens = self.calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step)
# then set their probabilities tof -inf
for idx in range(num_hypos):
lprobs[idx, banned_tokens[idx]] = -math.inf
assert lprobs.size() == (batch_size * num_beams, vocab_size)
_scores = lprobs + beam_scores[:, None].expand_as(lprobs) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis across beams)
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
# Take the best 2 x beam_size predictions for each example, we'll choose the first beam_size of these which don't predict eos to continue with.
next_scores, next_words = torch.topk(_scores, 2 * num_beams)
assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
# list of (batch_size * num_beams)
next_batch_beam = [] # Tuple(next score, next word, current position in the batch)
for batch_idx in range(batch_size):
# if we are done with this sentence (because we can't improve)
if done[batch_idx]: # then pad all associated hypotheses
assert (
len(finalized_hyps[batch_idx]) >= num_beams
), "Example can only be done if at least {} beams have been generated".format(num_beams)
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
# Otherwise generate some next word choices
next_sent_beam = []
# add next words for this sentence
for i, (idx, score) in enumerate(zip(next_words[batch_idx], next_scores[batch_idx])):
beam_id = idx // vocab_size
word_id = idx % vocab_size
assert prev_output_tokens.shape[1] == (step + 1)
if word_id.item() == eos_token_id:
if i >= num_beams:
continue
finalized_hyps[batch_idx].add(
prev_output_tokens[batch_idx * num_beams + beam_id].clone(), score.item(),
)
else:
next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))
if len(next_sent_beam) == num_beams: # TODO(SS): can we delete this?
break
# Check if were done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or finalized_hyps[batch_idx].is_done(
next_scores[batch_idx].max().item(), cur_len=step + 1,
)
assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
if all(done):
break
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_words = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order decoder inputs to [beam_idx]
prev_output_tokens = prev_output_tokens[beam_idx]
prev_output_tokens = torch.cat([prev_output_tokens, beam_words.unsqueeze(1)], dim=-1)
# re-order internal states
decoder_cache = self._reorder_cache(decoder_cache, beam_idx)
for batch_idx in range(batch_size):
# Add all open beam hypothesis to generated_hyps
if done[batch_idx]:
continue
offset = batch_idx * num_beams
for i in range(num_beams):
score = beam_scores[offset + i]
final_tokens = prev_output_tokens[offset + i]
finalized_hyps[batch_idx].add(final_tokens, score.item())
# select the best hypotheses
sent_lengths = input_ids.new(batch_size)
best = []
for i, hypotheses in enumerate(finalized_hyps):
best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
sent_lengths[i] = len(best_hyp)
best.append(best_hyp)
# shorter batches are filled with pad_token
if sent_lengths.min().item() != sent_lengths.max().item():
# TODO(SS): decoded = torch.rnn.utils.pad_sequence(best, batch_first=True, padding_value=pad_token_id)
sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1) # TODO(SS): same as step?
decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)
# fill with hypothesis and eos_token_id if necessary
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_id
else:
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
return decoded[:, 1:] # get rid of starting EOS
@staticmethod
def calc_banned_tokens(prev_output_tokens, num_hypos, no_repeat_ngram_size, step):
"""Copied from fairseq for no_repeat_ngram in beam_search"""
# TODO(SS): this can go on parent if there is demand
if step + 2 < no_repeat_ngram_size:
return [
[] for _ in range(num_hypos)
] # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
gen_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_output_tokens[idx].tolist()
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
k = tuple(ngram[:-1])
gen_ngrams[idx][k] = gen_ngrams[idx].get(k, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
"""Before decoding the next token, prevent decoding of ngrams that have already appeared"""
ngram_index = tuple(prev_output_tokens[hypo_idx, step + 2 - no_repeat_ngram_size : step + 1].tolist())
return gen_ngrams[hypo_idx].get(ngram_index, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
@add_start_docstrings(
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,

View File

@@ -587,6 +587,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids}
def prepare_scores_for_generation(self, scores, **kwargs):
return scores
def _do_output_past(self, outputs):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
has_output_past = getattr(self.config, "output_past", False)
@@ -940,20 +943,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if repetition_penalty != 1.0:
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len - 1)
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
for batch_idx in range(batch_size):
next_token_logits[
batch_idx, banned_tokens[batch_idx]
] = -10000.0 # set eos token prob to 0 as is done for attention masks
] = -float('inf')
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
for eos_token_id in eos_token_ids:
next_token_logits[
:, eos_token_id
] = -10000.0 # set eos token prob to 0 as is done for attention masks
] = -float('inf')
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
@@ -1037,12 +1041,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# generated hypotheses
generated_hyps = [
BeamHypotheses(num_beams, max_length - 1, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
# BeamHypotheses(num_beams, max_length - 2, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
]
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
# Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
# for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
if do_sample is False:
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
@@ -1068,41 +1074,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty,
)
if cur_len < min_length and eos_token_ids is not None:
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
if self.config.is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here?
scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
# set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
for eos_token_id in eos_token_ids:
next_token_logits[:, eos_token_id] = -10000.0 # set eos token prob to 0 as is done for attention masks
scores[:, eos_token_id] = -float('inf')
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
num_batch_hypotheses = batch_size * num_beams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len - 1)
for batch_idx in range(batch_size):
next_token_logits[
batch_idx, banned_tokens[batch_idx]
] = -10000.0 # set eos token prob to 0 as is done for attention masks
banned_batch_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float('inf')
# force eos to be chosen at end of generation for encoder-decoder models
# TODO (PVP): both these things are very hacky see whether it might be possible to solve this differently
if self.config.is_encoder_decoder:
if cur_len == 1:
self._force_token_ids_generation(next_token_logits, bos_token_id)
if cur_len == max_length - 1:
self._force_token_ids_generation(next_token_logits, eos_token_ids)
# self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length)
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(scores.shape, (batch_size * num_beams, vocab_size))
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# Top-p/top-k filtering
_scores = top_k_top_p_filtering(
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together to sample from all beam_idxs
_scores = _scores.contiguous().view(
batch_size, num_beams * vocab_size
@@ -1112,48 +1111,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_tokens = torch.multinomial(
F.softmax(_scores, dim=-1), num_samples=2 * num_beams
) # (batch_size, num_beams * 2)
# Compute next scores
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
# sort the sampled vector to make sure that the first num_beams samples are the best
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
else:
# do greedy beam search
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
# if self.config.is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here?
# import math
# scores[scores != scores] = -math.inf # block nans => seems very hacky here
# scores[:, pad_token_id] = -math.inf # => seems very hacky here
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
# if cur_len == 1: # Force BOS to be chosen => also very hacky ... seems also to work without this line
# scores[:, self.config.bos_token_id + 1 :] = -math.inf
# if cur_len == max_length - 1: # FORCE EOS to be chosen
# all_but_eos_mask = torch.tensor(
# [x for x in range(vocab_size) if x not in eos_token_ids],
# dtype=torch.long,
# device=next(self.parameters()).device,
# )
# scores[:, all_but_eos_mask] = -math.inf
# if eos_token_ids is not None and cur_len < min_length:
# for eos_token_id in eos_token_ids:
# scores[:, eos_token_id] = -math.inf # set eos token prob to 0 as is done for attention masks
#
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
# if no_repeat_ngram_size > 0:
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
# banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len - 1)
# for batch_idx in range(batch_size):
# scores[
# batch_idx, banned_tokens[batch_idx]
# ] = -math.inf # set eos token prob to 0 as is done for attention masks
assert scores.size() == (batch_size * num_beams, vocab_size)
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
next_scores = next_scores.view(
batch_size, num_beams * vocab_size
@@ -1164,16 +1130,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
# next batch beam content
# list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
next_batch_beam = []
# for each sentence
for batch_idx in range(batch_size):
# if we are done with this sentence
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
next_scores[batch_idx].max().item(), cur_len=cur_len
)
if done[batch_idx]:
assert (
len(generated_hyps[batch_idx]) >= num_beams
@@ -1188,15 +1150,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_sent_beam = []
# next tokens for this sentence
for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]):
for i, (idx, score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):
# get beam and word IDs
beam_id = idx // vocab_size
token_id = idx % vocab_size
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence
if eos_token_ids is not None and token_id.item() in eos_token_ids:
if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
# when passed to num_beams hypotheses, continue
if i >= num_beams:
continue
generated_hyps[batch_idx].add(
input_ids[effective_beam_id].clone(), score.item(),
)
@@ -1208,11 +1173,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if len(next_sent_beam) == num_beams:
break
# Check if were done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
next_scores[batch_idx].max().item(), cur_len=cur_len
)
# update next beam content
assert len(next_sent_beam) == num_beams, "Beam should always be full"
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (batch_idx + 1)
# stop when we are done with each sentence
if all(done):
break
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
@@ -1227,10 +1201,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if past:
past = self._reorder_cache(past, beam_idx)
# stop when we are done with each sentence
if all(done):
break
# extend attention_mask for new generated input
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1)
@@ -1299,7 +1269,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return decoded
# force one of token_ids to be generated by setting prob of all other tokens to 0.
def _force_token_ids_generation(self, logits, token_ids):
def _force_token_ids_generation(self, scores, token_ids):
if isinstance(token_ids, int):
token_ids = [token_ids]
all_but_token_ids_mask = torch.tensor(
@@ -1307,9 +1277,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
dtype=torch.long,
device=next(self.parameters()).device,
)
assert len(logits.shape) == 2, "logits should be of rank 2 with shape: [batch_size, vocab_size]"
logits[:, all_but_token_ids_mask] = -10000.0
return logits
assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
scores[:, all_but_token_ids_mask] = -float('inf')
@staticmethod
def _reorder_cache(past, beam_idx):
@@ -1326,9 +1295,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return past
def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, step):
def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
# Copied from fairseq for no_repeat_ngram in beam_search"""
if step + 2 < no_repeat_ngram_size:
if cur_len + 1 < no_repeat_ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
return [[] for _ in range(num_hypos)]
generated_ngrams = [{} for _ in range(num_hypos)]
@@ -1341,9 +1310,8 @@ def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, step):
def _get_generated_ngrams(hypo_idx):
# Before decoding the next token, prevent decoding of ngrams that have already appeared
start_idx = step + 2 - no_repeat_ngram_size
end_idx = step + 1
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:end_idx].tolist())
start_idx = cur_len + 1 - no_repeat_ngram_size
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx: cur_len].tolist())
return generated_ngrams[hypo_idx].get(ngram_idx, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]