Clean special token init in modeling_....py (#3264)
* make style * fix conflicts
This commit is contained in:
committed by
GitHub
parent
8becb73293
commit
95e00d0808
@@ -665,7 +665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
repetition_penalty=None,
|
||||
bos_token_id=None,
|
||||
pad_token_id=None,
|
||||
eos_token_ids=None,
|
||||
eos_token_id=None,
|
||||
length_penalty=None,
|
||||
no_repeat_ngram_size=None,
|
||||
num_return_sequences=None,
|
||||
@@ -713,6 +713,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
repetition_penalty: (`optional`) float
|
||||
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
||||
|
||||
pad_token_id: (`optional`) int
|
||||
Padding token. Default to specicic model pad_token_id or None if it does not exist.
|
||||
|
||||
bos_token_id: (`optional`) int
|
||||
BOS token. Defaults to bos_token_id as defined in the models config.
|
||||
|
||||
@@ -798,7 +801,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||
no_repeat_ngram_size = (
|
||||
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||
@@ -812,8 +815,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
batch_size = input_ids.shape[0] # overriden by the input batch_size
|
||||
else:
|
||||
batch_size = 1
|
||||
if isinstance(eos_token_ids, int):
|
||||
eos_token_ids = [eos_token_ids]
|
||||
|
||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
||||
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
||||
@@ -830,12 +831,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
assert pad_token_id is None or (
|
||||
isinstance(pad_token_id, int) and (pad_token_id >= 0)
|
||||
), "`pad_token_id` should be a positive integer."
|
||||
assert (eos_token_ids is None) or (
|
||||
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
|
||||
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
||||
assert (
|
||||
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
|
||||
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
|
||||
assert (eos_token_id is None) or (
|
||||
isinstance(eos_token_id, int) and (eos_token_id >= 0)
|
||||
), "`eos_token_id` should be a positive integer."
|
||||
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
||||
assert (
|
||||
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
|
||||
@@ -876,13 +877,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
elif attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||
|
||||
# set pad_token_id to eos_token_ids if not set. Important that this is done after
|
||||
# set pad_token_id to eos_token_id if not set. Important that this is done after
|
||||
# attention_mask is created
|
||||
if pad_token_id is None and eos_token_ids is not None:
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
logger.warning(
|
||||
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
|
||||
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
|
||||
)
|
||||
pad_token_id = eos_token_ids[0]
|
||||
pad_token_id = eos_token_id
|
||||
|
||||
# current position and vocab size
|
||||
vocab_size = self.config.vocab_size
|
||||
@@ -947,8 +948,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
num_return_sequences=num_return_sequences,
|
||||
length_penalty=length_penalty,
|
||||
@@ -971,8 +972,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
@@ -994,7 +995,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
no_repeat_ngram_size,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
eos_token_id,
|
||||
decoder_start_token_id,
|
||||
batch_size,
|
||||
encoder_outputs,
|
||||
@@ -1031,9 +1032,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
for eos_token_id in eos_token_ids:
|
||||
next_token_logits[:, eos_token_id] = -float("inf")
|
||||
if eos_token_id is not None and cur_len < min_length:
|
||||
next_token_logits[:, eos_token_id] = -float("inf")
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
@@ -1049,22 +1049,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
next_token = torch.argmax(next_token_logits, dim=-1)
|
||||
|
||||
# update generations and finished sentences
|
||||
if eos_token_ids is not None:
|
||||
# pad finished sentences if eos_token_ids exist
|
||||
if eos_token_id is not None:
|
||||
# pad finished sentences if eos_token_id exist
|
||||
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
|
||||
else:
|
||||
tokens_to_add = next_token
|
||||
|
||||
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
|
||||
|
||||
if eos_token_ids is not None:
|
||||
for eos_token_id in eos_token_ids:
|
||||
eos_in_sents = tokens_to_add == eos_token_id
|
||||
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
||||
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
|
||||
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
|
||||
# unfinished_sents is set to zero if eos in sentence
|
||||
unfinished_sents.mul_((~eos_in_sents).long())
|
||||
if eos_token_id is not None:
|
||||
eos_in_sents = tokens_to_add == eos_token_id
|
||||
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
||||
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
|
||||
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
|
||||
# unfinished_sents is set to zero if eos in sentence
|
||||
unfinished_sents.mul_((~eos_in_sents).long())
|
||||
|
||||
# stop when there is a </s> in each sentence, or if we exceed the maximul length
|
||||
if unfinished_sents.max() == 0:
|
||||
@@ -1106,7 +1105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
no_repeat_ngram_size,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
eos_token_id,
|
||||
decoder_start_token_id,
|
||||
batch_size,
|
||||
num_return_sequences,
|
||||
@@ -1163,9 +1162,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
for eos_token_id in eos_token_ids:
|
||||
scores[:, eos_token_id] = -float("inf")
|
||||
if eos_token_id is not None and cur_len < min_length:
|
||||
scores[:, eos_token_id] = -float("inf")
|
||||
|
||||
if no_repeat_ngram_size > 0:
|
||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
@@ -1225,7 +1223,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
len(generated_hyps[batch_idx]) >= num_beams
|
||||
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
|
||||
assert (
|
||||
eos_token_ids is not None and pad_token_id is not None
|
||||
eos_token_id is not None and pad_token_id is not None
|
||||
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
|
||||
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
||||
continue
|
||||
@@ -1244,7 +1242,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
effective_beam_id = batch_idx * num_beams + beam_id
|
||||
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
|
||||
if (eos_token_id is not None) and (token_id.item() is eos_token_id):
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
@@ -1303,8 +1301,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
continue
|
||||
|
||||
# test that beam scores match previously calculated scores if not eos and batch_idx not done
|
||||
if eos_token_ids is not None and all(
|
||||
(token_id % vocab_size).item() not in eos_token_ids for token_id in next_tokens[batch_idx]
|
||||
if eos_token_id is not None and all(
|
||||
(token_id % vocab_size).item() is not eos_token_id for token_id in next_tokens[batch_idx]
|
||||
):
|
||||
assert torch.all(
|
||||
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
|
||||
@@ -1346,7 +1344,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded[i, sent_lengths[i]] = eos_token_ids[0]
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
else:
|
||||
# none of the hypotheses have an eos_token
|
||||
assert (len(hypo) == max_length for hypo in best)
|
||||
|
||||
Reference in New Issue
Block a user