|
|
|
@@ -609,6 +609,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
input_ids=None,
|
|
|
|
input_ids=None,
|
|
|
|
max_length=None,
|
|
|
|
max_length=None,
|
|
|
|
|
|
|
|
min_length=None,
|
|
|
|
do_sample=True,
|
|
|
|
do_sample=True,
|
|
|
|
num_beams=None,
|
|
|
|
num_beams=None,
|
|
|
|
temperature=None,
|
|
|
|
temperature=None,
|
|
|
|
@@ -713,6 +714,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
max_length = max_length if max_length is not None else self.config.max_length
|
|
|
|
max_length = max_length if max_length is not None else self.config.max_length
|
|
|
|
|
|
|
|
min_length = min_length if min_length is not None else self.config.min_length
|
|
|
|
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
|
|
|
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
|
|
|
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
|
|
|
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
|
|
|
temperature = temperature if temperature is not None else self.config.temperature
|
|
|
|
temperature = temperature if temperature is not None else self.config.temperature
|
|
|
|
@@ -735,6 +737,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
eos_token_ids = [eos_token_ids]
|
|
|
|
eos_token_ids = [eos_token_ids]
|
|
|
|
|
|
|
|
|
|
|
|
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
|
|
|
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
|
|
|
|
|
|
|
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
|
|
|
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
|
|
|
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
|
|
|
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
|
|
|
|
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
|
|
|
|
assert temperature > 0, "`temperature` should be strictly positive."
|
|
|
|
assert temperature > 0, "`temperature` should be strictly positive."
|
|
|
|
@@ -824,12 +827,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
encoder_inputs = input_ids
|
|
|
|
encoder_inputs = input_ids
|
|
|
|
input_ids = torch.full(
|
|
|
|
input_ids = torch.full(
|
|
|
|
(effective_batch_size * num_beams, 1),
|
|
|
|
(effective_batch_size * num_beams, 1),
|
|
|
|
# eos_token_id,
|
|
|
|
# eos_token_id, # Why eos_token_id here? bos_token_id makes more sense no?
|
|
|
|
bos_token_id,
|
|
|
|
bos_token_id,
|
|
|
|
dtype=torch.long,
|
|
|
|
dtype=torch.long,
|
|
|
|
device=next(self.parameters()).device,
|
|
|
|
device=next(self.parameters()).device,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
cur_len = 0
|
|
|
|
cur_len = 1
|
|
|
|
self.model.decoder.generation_mode = True
|
|
|
|
self.model.decoder.generation_mode = True
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
encoder_inputs = None
|
|
|
|
encoder_inputs = None
|
|
|
|
@@ -840,6 +843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
input_ids,
|
|
|
|
input_ids,
|
|
|
|
cur_len,
|
|
|
|
cur_len,
|
|
|
|
max_length,
|
|
|
|
max_length,
|
|
|
|
|
|
|
|
min_length,
|
|
|
|
do_sample,
|
|
|
|
do_sample,
|
|
|
|
temperature,
|
|
|
|
temperature,
|
|
|
|
top_k,
|
|
|
|
top_k,
|
|
|
|
@@ -859,6 +863,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
input_ids,
|
|
|
|
input_ids,
|
|
|
|
cur_len,
|
|
|
|
cur_len,
|
|
|
|
max_length,
|
|
|
|
max_length,
|
|
|
|
|
|
|
|
min_length,
|
|
|
|
do_sample,
|
|
|
|
do_sample,
|
|
|
|
temperature,
|
|
|
|
temperature,
|
|
|
|
top_k,
|
|
|
|
top_k,
|
|
|
|
@@ -877,6 +882,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
input_ids,
|
|
|
|
input_ids,
|
|
|
|
cur_len,
|
|
|
|
cur_len,
|
|
|
|
max_length,
|
|
|
|
max_length,
|
|
|
|
|
|
|
|
min_length,
|
|
|
|
do_sample,
|
|
|
|
do_sample,
|
|
|
|
temperature,
|
|
|
|
temperature,
|
|
|
|
top_k,
|
|
|
|
top_k,
|
|
|
|
@@ -911,6 +917,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
if repetition_penalty != 1.0:
|
|
|
|
if repetition_penalty != 1.0:
|
|
|
|
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
|
|
|
|
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if eos_token_ids is not None and cur_len < min_length:
|
|
|
|
|
|
|
|
for eos_token_id in eos_token_ids:
|
|
|
|
|
|
|
|
next_token_logits[:, eos_token_id] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
|
|
|
|
|
|
|
|
|
|
|
if do_sample:
|
|
|
|
if do_sample:
|
|
|
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
|
|
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
|
|
|
if temperature != 1.0:
|
|
|
|
if temperature != 1.0:
|
|
|
|
@@ -965,6 +975,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
input_ids,
|
|
|
|
input_ids,
|
|
|
|
cur_len,
|
|
|
|
cur_len,
|
|
|
|
max_length,
|
|
|
|
max_length,
|
|
|
|
|
|
|
|
min_length,
|
|
|
|
do_sample,
|
|
|
|
do_sample,
|
|
|
|
temperature,
|
|
|
|
temperature,
|
|
|
|
top_k,
|
|
|
|
top_k,
|
|
|
|
@@ -1022,6 +1033,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty
|
|
|
|
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if eos_token_ids is not None and cur_len < min_length:
|
|
|
|
|
|
|
|
for eos_token_id in eos_token_ids:
|
|
|
|
|
|
|
|
next_token_logits[:, eos_token_id] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
|
|
|
|
|
|
|
|
|
|
|
if do_sample:
|
|
|
|
if do_sample:
|
|
|
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
|
|
|
# Temperature (higher temperature => more likely to sample low probability tokens)
|
|
|
|
if temperature != 1.0:
|
|
|
|
if temperature != 1.0:
|
|
|
|
@@ -1056,18 +1071,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
|
|
|
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
|
|
|
|
|
|
|
|
|
|
|
if is_encoder_decoder: # TODO(PVP) to be refactored later
|
|
|
|
if is_encoder_decoder: # TODO(PVP) to be refactored later
|
|
|
|
import math
|
|
|
|
# scores[scores != scores] = -math.inf # block nans => seems very hacky here
|
|
|
|
# scores[scores != scores] = -math.inf # block nans
|
|
|
|
# scores[:, pad_token_id] = -math.inf => seems very hacky here
|
|
|
|
# scores[:, pad_token_id] = -math.inf
|
|
|
|
|
|
|
|
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
|
|
|
|
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
|
|
|
|
# if cur_len == 0: # Force BOS to be chosen
|
|
|
|
# if cur_len == 0: # Force BOS to be chosen => also very hacky ... seems also to work without this line
|
|
|
|
# scores[:, self.config.bos_token_id + 1 :] = -math.inf # TODO(PVP) should not use bos_token_id here
|
|
|
|
# scores[:, self.config.bos_token_id + 1 :] = -math.inf
|
|
|
|
# elif cur_len < min_len: # Prevent EOS from being chosen TODO: for the moment don't think about min_len
|
|
|
|
if cur_len == max_length - 1: # FORCE EOS to be chosen
|
|
|
|
# scores[:, eos_token_ids[0]] = -math.inf
|
|
|
|
all_but_eos_mask = torch.tensor([x for x in range(vocab_size) if x not in eos_token_ids], dtype=torch.long, device=next(self.parameters()).device)
|
|
|
|
# elif cur_len == max_length: # FORCE EOS to be chosen
|
|
|
|
scores[:, all_but_eos_mask] = -10000.0
|
|
|
|
if cur_len == max_length: # FORCE EOS to be chosen
|
|
|
|
|
|
|
|
scores[:, :eos_token_ids[0]] = -math.inf
|
|
|
|
|
|
|
|
scores[:, eos_token_ids[0] + 1 :] = -math.inf
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert scores.size() == (batch_size * num_beams, vocab_size)
|
|
|
|
assert scores.size() == (batch_size * num_beams, vocab_size)
|
|
|
|
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
|
|
|
|
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
|
|
|
|
@@ -1194,7 +1205,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|
|
|
# shorter batches are filled with pad_token
|
|
|
|
# shorter batches are filled with pad_token
|
|
|
|
if sent_lengths.min().item() != sent_lengths.max().item():
|
|
|
|
if sent_lengths.min().item() != sent_lengths.max().item():
|
|
|
|
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
|
|
|
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
|
|
|
sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1)
|
|
|
|
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
|
|
|
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
|
|
|
|
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
|
|
|
|
|
|
|
|
|
|
|
|
# fill with hypothesis and eos_token_id if necessary
|
|
|
|
# fill with hypothesis and eos_token_id if necessary
|
|
|
|
|