refactored code a bit and made more generic
This commit is contained in:
@@ -609,6 +609,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
self,
|
||||
input_ids=None,
|
||||
max_length=None,
|
||||
min_length=None,
|
||||
do_sample=True,
|
||||
num_beams=None,
|
||||
temperature=None,
|
||||
@@ -713,6 +714,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
)
|
||||
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
temperature = temperature if temperature is not None else self.config.temperature
|
||||
@@ -735,6 +737,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
eos_token_ids = [eos_token_ids]
|
||||
|
||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
||||
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
|
||||
assert temperature > 0, "`temperature` should be strictly positive."
|
||||
@@ -824,12 +827,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
encoder_inputs = input_ids
|
||||
input_ids = torch.full(
|
||||
(effective_batch_size * num_beams, 1),
|
||||
# eos_token_id,
|
||||
# eos_token_id, # Why eos_token_id here? bos_token_id makes more sense no?
|
||||
bos_token_id,
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
cur_len = 0
|
||||
cur_len = 1
|
||||
self.model.decoder.generation_mode = True
|
||||
else:
|
||||
encoder_inputs = None
|
||||
@@ -840,6 +843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
input_ids,
|
||||
cur_len,
|
||||
max_length,
|
||||
min_length,
|
||||
do_sample,
|
||||
temperature,
|
||||
top_k,
|
||||
@@ -859,6 +863,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
input_ids,
|
||||
cur_len,
|
||||
max_length,
|
||||
min_length,
|
||||
do_sample,
|
||||
temperature,
|
||||
top_k,
|
||||
@@ -877,6 +882,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
input_ids,
|
||||
cur_len,
|
||||
max_length,
|
||||
min_length,
|
||||
do_sample,
|
||||
temperature,
|
||||
top_k,
|
||||
@@ -911,6 +917,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
if repetition_penalty != 1.0:
|
||||
self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
|
||||
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
for eos_token_id in eos_token_ids:
|
||||
next_token_logits[:, eos_token_id] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
if temperature != 1.0:
|
||||
@@ -965,6 +975,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
input_ids,
|
||||
cur_len,
|
||||
max_length,
|
||||
min_length,
|
||||
do_sample,
|
||||
temperature,
|
||||
top_k,
|
||||
@@ -1022,6 +1033,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty
|
||||
)
|
||||
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
for eos_token_id in eos_token_ids:
|
||||
next_token_logits[:, eos_token_id] = -10000.0 # set eos token prob to 0 as is done for attention masks
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
if temperature != 1.0:
|
||||
@@ -1056,18 +1071,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
if is_encoder_decoder: # TODO(PVP) to be refactored later
|
||||
import math
|
||||
# scores[scores != scores] = -math.inf # block nans
|
||||
# scores[:, pad_token_id] = -math.inf
|
||||
# scores[scores != scores] = -math.inf # block nans => seems very hacky here
|
||||
# scores[:, pad_token_id] = -math.inf => seems very hacky here
|
||||
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
|
||||
# if cur_len == 0: # Force BOS to be chosen
|
||||
# scores[:, self.config.bos_token_id + 1 :] = -math.inf # TODO(PVP) should not use bos_token_id here
|
||||
# elif cur_len < min_len: # Prevent EOS from being chosen TODO: for the moment don't think about min_len
|
||||
# scores[:, eos_token_ids[0]] = -math.inf
|
||||
# elif cur_len == max_length: # FORCE EOS to be chosen
|
||||
if cur_len == max_length: # FORCE EOS to be chosen
|
||||
scores[:, :eos_token_ids[0]] = -math.inf
|
||||
scores[:, eos_token_ids[0] + 1 :] = -math.inf
|
||||
# if cur_len == 0: # Force BOS to be chosen => also very hacky ... seems also to work without this line
|
||||
# scores[:, self.config.bos_token_id + 1 :] = -math.inf
|
||||
if cur_len == max_length - 1: # FORCE EOS to be chosen
|
||||
all_but_eos_mask = torch.tensor([x for x in range(vocab_size) if x not in eos_token_ids], dtype=torch.long, device=next(self.parameters()).device)
|
||||
scores[:, all_but_eos_mask] = -10000.0
|
||||
|
||||
assert scores.size() == (batch_size * num_beams, vocab_size)
|
||||
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
|
||||
@@ -1194,7 +1205,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# shorter batches are filled with pad_token
|
||||
if sent_lengths.min().item() != sent_lengths.max().item():
|
||||
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1)
|
||||
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
||||
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
|
||||
|
||||
# fill with hypothesis and eos_token_id if necessary
|
||||
|
||||
Reference in New Issue
Block a user