|
|
|
|
@@ -19,6 +19,8 @@ import logging
|
|
|
|
|
import os
|
|
|
|
|
import typing
|
|
|
|
|
|
|
|
|
|
import ipdb
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from torch import nn
|
|
|
|
|
from torch.nn import CrossEntropyLoss
|
|
|
|
|
@@ -623,6 +625,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
length_penalty=None,
|
|
|
|
|
no_repeat_ngram_size=None,
|
|
|
|
|
num_return_sequences=None,
|
|
|
|
|
attention_mask=None,
|
|
|
|
|
):
|
|
|
|
|
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
|
|
|
|
|
and beam-search.
|
|
|
|
|
@@ -791,6 +794,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
num_beams >= num_return_sequences
|
|
|
|
|
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
|
|
|
|
|
|
|
|
|
|
# create attention mask if necessary
|
|
|
|
|
# TODO (PVP): this should later be handled by the forward fn() in each model
|
|
|
|
|
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
|
|
|
|
|
attention_mask = input_ids.ne(pad_token_id).long()
|
|
|
|
|
elif attention_mask is None:
|
|
|
|
|
attention_mask = input_ids.new_ones(input_ids.shape)
|
|
|
|
|
|
|
|
|
|
# set pad_token_id to eos_token_ids if not set. Important that this is done after
|
|
|
|
|
# attention_mask is created
|
|
|
|
|
if pad_token_id is None and eos_token_ids is not None:
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
|
|
|
|
|
@@ -812,15 +824,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
if num_return_sequences > 1 or num_beams > 1:
|
|
|
|
|
input_ids_len = input_ids.shape[-1]
|
|
|
|
|
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
|
|
|
|
|
attention_mask = attention_mask.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
|
|
|
|
|
|
|
|
|
|
input_ids = input_ids.contiguous().view(
|
|
|
|
|
effective_batch_size * num_beams, input_ids_len
|
|
|
|
|
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
|
|
|
|
attention_mask = attention_mask.contiguous().view(
|
|
|
|
|
effective_batch_size * num_beams, input_ids_len
|
|
|
|
|
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
|
|
|
|
|
|
|
|
|
# TODO (PVP): probably not the best way to check whether model is encoder decoder
|
|
|
|
|
is_encoder_decoder = (
|
|
|
|
|
hasattr(self, "model") and hasattr(self.model, "decoder") and hasattr(self.model, "encoder")
|
|
|
|
|
)
|
|
|
|
|
if is_encoder_decoder:
|
|
|
|
|
if self.config.is_encoder_decoder:
|
|
|
|
|
eos_token_id = eos_token_ids[0]
|
|
|
|
|
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
|
|
|
|
|
assert eos_token_id is not None, "Encoder Decoder Models need to have a eos_token_id"
|
|
|
|
|
@@ -828,8 +841,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
encoder_inputs = input_ids
|
|
|
|
|
input_ids = torch.full(
|
|
|
|
|
(effective_batch_size * num_beams, 1),
|
|
|
|
|
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
|
|
|
|
|
# eos_token_id,
|
|
|
|
|
bos_token_id,
|
|
|
|
|
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
|
|
|
|
|
dtype=torch.long,
|
|
|
|
|
device=next(self.parameters()).device,
|
|
|
|
|
)
|
|
|
|
|
@@ -851,6 +865,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
top_p,
|
|
|
|
|
repetition_penalty,
|
|
|
|
|
no_repeat_ngram_size,
|
|
|
|
|
bos_token_id,
|
|
|
|
|
pad_token_id,
|
|
|
|
|
eos_token_ids,
|
|
|
|
|
effective_batch_size,
|
|
|
|
|
@@ -859,6 +874,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
num_beams,
|
|
|
|
|
vocab_size,
|
|
|
|
|
encoder_inputs,
|
|
|
|
|
attention_mask,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
output = self._generate_no_beam_search(
|
|
|
|
|
@@ -876,6 +892,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
eos_token_ids,
|
|
|
|
|
effective_batch_size,
|
|
|
|
|
encoder_inputs,
|
|
|
|
|
attention_mask,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return output
|
|
|
|
|
@@ -896,6 +913,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
eos_token_ids,
|
|
|
|
|
batch_size,
|
|
|
|
|
encoder_inputs,
|
|
|
|
|
attention_mask
|
|
|
|
|
):
|
|
|
|
|
""" Generate sequences for each example without beam search (num_beams == 1).
|
|
|
|
|
All returned sequence are generated independantly.
|
|
|
|
|
@@ -906,7 +924,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
|
|
|
|
|
past = None
|
|
|
|
|
while cur_len < max_length:
|
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs)
|
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask)
|
|
|
|
|
|
|
|
|
|
outputs = self(**model_inputs)
|
|
|
|
|
next_token_logits = outputs[0][:, -1, :]
|
|
|
|
|
@@ -922,7 +940,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
|
|
|
|
if no_repeat_ngram_size > 0:
|
|
|
|
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
|
|
|
|
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
|
|
|
|
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len - 1)
|
|
|
|
|
for batch_idx in range(batch_size):
|
|
|
|
|
next_token_logits[
|
|
|
|
|
batch_idx, banned_tokens[batch_idx]
|
|
|
|
|
@@ -968,6 +986,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
if unfinished_sents.max() == 0:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
# extend attention_mask for new generated input
|
|
|
|
|
if self.config.is_encoder_decoder is False:
|
|
|
|
|
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1)
|
|
|
|
|
|
|
|
|
|
cur_len = cur_len + 1
|
|
|
|
|
|
|
|
|
|
# if there are different sentences lengths in the batch, some batches have to be padded
|
|
|
|
|
@@ -995,6 +1017,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
top_p,
|
|
|
|
|
repetition_penalty,
|
|
|
|
|
no_repeat_ngram_size,
|
|
|
|
|
bos_token_id,
|
|
|
|
|
pad_token_id,
|
|
|
|
|
eos_token_ids,
|
|
|
|
|
batch_size,
|
|
|
|
|
@@ -1003,12 +1026,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
num_beams,
|
|
|
|
|
vocab_size,
|
|
|
|
|
encoder_inputs,
|
|
|
|
|
attention_mask,
|
|
|
|
|
):
|
|
|
|
|
""" Generate sequences for each example with beam search.
|
|
|
|
|
"""
|
|
|
|
|
is_encoder_decoder = (
|
|
|
|
|
hasattr(self, "model") and hasattr(self.model, "decoder") and hasattr(self.model, "encoder")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# generated hypotheses
|
|
|
|
|
generated_hyps = [
|
|
|
|
|
@@ -1029,7 +1050,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
done = [False for _ in range(batch_size)]
|
|
|
|
|
|
|
|
|
|
while cur_len < max_length:
|
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs)
|
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask)
|
|
|
|
|
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
|
|
|
|
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
|
|
|
|
|
|
|
|
|
@@ -1043,20 +1064,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if cur_len < min_length and eos_token_ids is not None:
|
|
|
|
|
for eos_token_id in eos_token_ids:
|
|
|
|
|
next_token_logits[:, eos_token_id] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
|
|
|
|
|
|
|
|
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
|
|
|
|
if no_repeat_ngram_size > 0:
|
|
|
|
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
|
|
|
|
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
|
|
|
|
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len - 1)
|
|
|
|
|
for batch_idx in range(batch_size):
|
|
|
|
|
next_token_logits[
|
|
|
|
|
batch_idx, banned_tokens[batch_idx]
|
|
|
|
|
] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
|
|
|
|
|
|
|
|
|
if eos_token_ids is not None and cur_len < min_length:
|
|
|
|
|
for eos_token_id in eos_token_ids:
|
|
|
|
|
next_token_logits[
|
|
|
|
|
:, eos_token_id
|
|
|
|
|
] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
|
|
|
|
# force eos to be chosen at end of generation for encoder-decoder models
|
|
|
|
|
# TODO (PVP): both these things are very hacky see whether it might be possible to solve this differently
|
|
|
|
|
if self.config.is_encoder_decoder:
|
|
|
|
|
# self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length)
|
|
|
|
|
if cur_len == 1:
|
|
|
|
|
self._force_token_ids_generation(next_token_logits, bos_token_id)
|
|
|
|
|
if cur_len == max_length - 1:
|
|
|
|
|
self._force_token_ids_generation(next_token_logits, eos_token_ids)
|
|
|
|
|
|
|
|
|
|
if do_sample:
|
|
|
|
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
|
|
|
|
@@ -1091,19 +1119,33 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
# do greedy beam search
|
|
|
|
|
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
|
|
|
|
|
|
|
|
|
if is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here?
|
|
|
|
|
# scores[scores != scores] = -math.inf # block nans => seems very hacky here
|
|
|
|
|
# scores[:, pad_token_id] = -math.inf => seems very hacky here
|
|
|
|
|
# if self.config.is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here?
|
|
|
|
|
# import math
|
|
|
|
|
# scores[scores != scores] = -math.inf # block nans => seems very hacky here
|
|
|
|
|
# scores[:, pad_token_id] = -math.inf # => seems very hacky here
|
|
|
|
|
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
|
|
|
|
|
# if cur_len == 0: # Force BOS to be chosen => also very hacky ... seems also to work without this line
|
|
|
|
|
# scores[:, self.config.bos_token_id + 1 :] = -math.inf
|
|
|
|
|
if cur_len == max_length - 1: # FORCE EOS to be chosen
|
|
|
|
|
all_but_eos_mask = torch.tensor(
|
|
|
|
|
[x for x in range(vocab_size) if x not in eos_token_ids],
|
|
|
|
|
dtype=torch.long,
|
|
|
|
|
device=next(self.parameters()).device,
|
|
|
|
|
)
|
|
|
|
|
scores[:, all_but_eos_mask] = -10000.0
|
|
|
|
|
# if cur_len == 1: # Force BOS to be chosen => also very hacky ... seems also to work without this line
|
|
|
|
|
# scores[:, self.config.bos_token_id + 1 :] = -math.inf
|
|
|
|
|
# if cur_len == max_length - 1: # FORCE EOS to be chosen
|
|
|
|
|
# all_but_eos_mask = torch.tensor(
|
|
|
|
|
# [x for x in range(vocab_size) if x not in eos_token_ids],
|
|
|
|
|
# dtype=torch.long,
|
|
|
|
|
# device=next(self.parameters()).device,
|
|
|
|
|
# )
|
|
|
|
|
# scores[:, all_but_eos_mask] = -math.inf
|
|
|
|
|
|
|
|
|
|
# if eos_token_ids is not None and cur_len < min_length:
|
|
|
|
|
# for eos_token_id in eos_token_ids:
|
|
|
|
|
# scores[:, eos_token_id] = -math.inf # set eos token prob to 0 as is done for attention masks
|
|
|
|
|
#
|
|
|
|
|
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
|
|
|
|
# if no_repeat_ngram_size > 0:
|
|
|
|
|
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
|
|
|
|
# banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len - 1)
|
|
|
|
|
# for batch_idx in range(batch_size):
|
|
|
|
|
# scores[
|
|
|
|
|
# batch_idx, banned_tokens[batch_idx]
|
|
|
|
|
# ] = -math.inf # set eos token prob to 0 as is done for attention masks
|
|
|
|
|
|
|
|
|
|
assert scores.size() == (batch_size * num_beams, vocab_size)
|
|
|
|
|
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
|
|
|
|
|
@@ -1126,7 +1168,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
|
|
|
|
|
# if we are done with this sentence
|
|
|
|
|
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
|
|
|
|
next_scores[batch_idx].max().item()
|
|
|
|
|
next_scores[batch_idx].max().item(), cur_len=cur_len
|
|
|
|
|
)
|
|
|
|
|
if done[batch_idx]:
|
|
|
|
|
assert (
|
|
|
|
|
@@ -1185,6 +1227,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
if all(done):
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
# extend attention_mask for new generated input
|
|
|
|
|
if self.config.is_encoder_decoder is False:
|
|
|
|
|
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, attention_mask.shape[-1]))], dim=-1)
|
|
|
|
|
|
|
|
|
|
# update current length
|
|
|
|
|
cur_len = cur_len + 1
|
|
|
|
|
|
|
|
|
|
@@ -1243,11 +1289,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
|
assert (len(hypo) == max_length for hypo in best)
|
|
|
|
|
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
|
|
|
|
|
|
|
|
|
if is_encoder_decoder:
|
|
|
|
|
if self.config.is_encoder_decoder:
|
|
|
|
|
# do not return first <BOS> token
|
|
|
|
|
return decoded[:, 1:]
|
|
|
|
|
return decoded
|
|
|
|
|
|
|
|
|
|
# force one of token_ids to be generated by setting prob of all other tokens to 0.
|
|
|
|
|
def _force_token_ids_generation(self, logits, token_ids):
|
|
|
|
|
if isinstance(token_ids, int):
|
|
|
|
|
token_ids = [token_ids]
|
|
|
|
|
all_but_token_ids_mask = torch.tensor(
|
|
|
|
|
[x for x in range(self.config.vocab_size) if x not in token_ids],
|
|
|
|
|
dtype=torch.long,
|
|
|
|
|
device=next(self.parameters()).device,
|
|
|
|
|
)
|
|
|
|
|
assert len(logits.shape) == 2, "logits should be of rank 2 with shape: [batch_size, vocab_size]"
|
|
|
|
|
logits[:, all_but_token_ids_mask] = -10000.0
|
|
|
|
|
return logits
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _reorder_cache(past, beam_idx):
|
|
|
|
|
reordered_past = []
|
|
|
|
|
|