Clean special token init in modeling_....py (#3264)
* make style * fix conflicts
This commit is contained in:
committed by
GitHub
parent
8becb73293
commit
95e00d0808
@@ -35,7 +35,7 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
|
||||
min_length=min_length + 1, # +1 from original because we start at step=1
|
||||
no_repeat_ngram_size=3,
|
||||
early_stopping=True,
|
||||
decoder_start_token_id=model.config.eos_token_ids[0],
|
||||
decoder_start_token_id=model.config.eos_token_id,
|
||||
)
|
||||
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
|
||||
for hypothesis in dec:
|
||||
|
||||
@@ -223,6 +223,7 @@ if is_torch_available():
|
||||
BartForSequenceClassification,
|
||||
BartModel,
|
||||
BartForConditionalGeneration,
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
)
|
||||
from .modeling_roberta import (
|
||||
RobertaForMaskedLM,
|
||||
|
||||
@@ -124,9 +124,12 @@ class AlbertConfig(PretrainedConfig):
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
classifier_dropout_prob=0.1,
|
||||
pad_token_id=0,
|
||||
bos_token_id=2,
|
||||
eos_token_id=3,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_size = embedding_size
|
||||
|
||||
@@ -41,9 +41,6 @@ class BartConfig(PretrainedConfig):
|
||||
activation_dropout=0.0,
|
||||
activation_function="gelu",
|
||||
vocab_size=50265,
|
||||
bos_token_id=0,
|
||||
pad_token_id=1,
|
||||
eos_token_ids=[2],
|
||||
d_model=1024,
|
||||
encoder_ffn_dim=4096,
|
||||
encoder_layers=12,
|
||||
@@ -61,6 +58,9 @@ class BartConfig(PretrainedConfig):
|
||||
output_past=False,
|
||||
num_labels=3,
|
||||
is_encoder_decoder=True,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
**common_kwargs
|
||||
):
|
||||
r"""
|
||||
@@ -74,7 +74,7 @@ class BartConfig(PretrainedConfig):
|
||||
output_past=output_past,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
**common_kwargs,
|
||||
)
|
||||
|
||||
@@ -124,9 +124,10 @@ class BertConfig(PretrainedConfig):
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
pad_token_id=0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
@@ -113,9 +113,10 @@ class DistilBertConfig(PretrainedConfig):
|
||||
initializer_range=0.02,
|
||||
qa_dropout=0.1,
|
||||
seq_classif_dropout=0.2,
|
||||
pad_token_id=0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(**kwargs, pad_token_id=pad_token_id)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.sinusoidal_pos_embds = sinusoidal_pos_embds
|
||||
|
||||
@@ -145,9 +145,9 @@ class FlaubertConfig(XLMConfig):
|
||||
pretrained_config_archive_map = FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
model_type = "flaubert"
|
||||
|
||||
def __init__(self, layerdrop=0.0, pre_norm=False, **kwargs):
|
||||
def __init__(self, layerdrop=0.0, pre_norm=False, pad_token_id=2, bos_token_id=0, **kwargs):
|
||||
"""Constructs FlaubertConfig.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
|
||||
self.layerdrop = layerdrop
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
@@ -142,7 +142,7 @@ class GPT2Config(PretrainedConfig):
|
||||
eos_token_id=50256,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.n_ctx = n_ctx
|
||||
@@ -163,7 +163,7 @@ class GPT2Config(PretrainedConfig):
|
||||
self.summary_proj_to_labels = summary_proj_to_labels
|
||||
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_ids = [eos_token_id]
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
@property
|
||||
def max_position_embeddings(self):
|
||||
|
||||
@@ -66,3 +66,8 @@ class RobertaConfig(BertConfig):
|
||||
"""
|
||||
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
model_type = "roberta"
|
||||
|
||||
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
|
||||
"""Constructs FlaubertConfig.
|
||||
"""
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
@@ -77,11 +77,11 @@ class T5Config(PretrainedConfig):
|
||||
initializer_factor=1.0,
|
||||
is_encoder_decoder=True,
|
||||
pad_token_id=0,
|
||||
eos_token_ids=[1],
|
||||
eos_token_id=1,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
is_encoder_decoder=is_encoder_decoder, **kwargs,
|
||||
pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, **kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.n_positions = n_positions
|
||||
|
||||
@@ -152,7 +152,7 @@ class TransfoXLConfig(PretrainedConfig):
|
||||
eos_token_id=0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.cutoffs = []
|
||||
@@ -187,8 +187,6 @@ class TransfoXLConfig(PretrainedConfig):
|
||||
self.init_std = init_std
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
|
||||
self.eos_token_ids = [eos_token_id]
|
||||
|
||||
@property
|
||||
def max_position_embeddings(self):
|
||||
return self.tgt_len + self.ext_len + self.mem_len
|
||||
|
||||
@@ -80,7 +80,7 @@ class PretrainedConfig(object):
|
||||
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
||||
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
||||
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
||||
self.eos_token_ids = kwargs.pop("eos_token_ids", None)
|
||||
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
||||
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
||||
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||
|
||||
@@ -194,13 +194,13 @@ class XLMConfig(PretrainedConfig):
|
||||
end_n_top=5,
|
||||
mask_token_id=0,
|
||||
lang_id=0,
|
||||
bos_token_id=0,
|
||||
pad_token_id=2,
|
||||
bos_token_id=0,
|
||||
**kwargs
|
||||
):
|
||||
"""Constructs XLMConfig.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.emb_dim = emb_dim
|
||||
self.n_layers = n_layers
|
||||
@@ -236,9 +236,6 @@ class XLMConfig(PretrainedConfig):
|
||||
if "n_words" in kwargs:
|
||||
self.n_words = kwargs["n_words"]
|
||||
|
||||
self.bos_token_id = bos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
|
||||
@property
|
||||
def n_words(self): # For backward compatibility
|
||||
return self.vocab_size
|
||||
|
||||
@@ -155,14 +155,14 @@ class XLNetConfig(PretrainedConfig):
|
||||
summary_last_dropout=0.1,
|
||||
start_n_top=5,
|
||||
end_n_top=5,
|
||||
bos_token_id=1,
|
||||
pad_token_id=5,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
**kwargs
|
||||
):
|
||||
"""Constructs XLNetConfig.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.d_model = d_model
|
||||
self.n_layer = n_layer
|
||||
@@ -193,7 +193,7 @@ class XLNetConfig(PretrainedConfig):
|
||||
|
||||
self.bos_token_id = bos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.eos_token_ids = [eos_token_id]
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
@property
|
||||
def max_position_embeddings(self):
|
||||
|
||||
@@ -906,8 +906,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||
if cur_len == 1:
|
||||
self._force_token_ids_generation(scores, self.config.bos_token_id)
|
||||
if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None:
|
||||
self._force_token_ids_generation(scores, self.config.eos_token_ids[0])
|
||||
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
self._force_token_ids_generation(scores, self.config.eos_token_id)
|
||||
return scores
|
||||
|
||||
@staticmethod
|
||||
@@ -1003,7 +1003,7 @@ class BartForSequenceClassification(PretrainedBartModel):
|
||||
encoder_outputs=encoder_outputs,
|
||||
)
|
||||
x = outputs[0] # last hidden state
|
||||
eos_mask = input_ids.eq(self.config.eos_token_ids[0])
|
||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
|
||||
|
||||
@@ -469,7 +469,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
repetition_penalty=None,
|
||||
bos_token_id=None,
|
||||
pad_token_id=None,
|
||||
eos_token_ids=None,
|
||||
eos_token_id=None,
|
||||
length_penalty=None,
|
||||
no_repeat_ngram_size=None,
|
||||
num_return_sequences=None,
|
||||
@@ -518,13 +518,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
||||
|
||||
bos_token_id: (`optional`) int
|
||||
Beginning of sentence token if no prompt is provided. Default to 0.
|
||||
Beginning of sentence token if no prompt is provided. Default to specicic model bos_token_id or None if it does not exist.
|
||||
|
||||
pad_token_id: (`optional`) int
|
||||
Pad token. Defaults to pad_token_id as defined in the models config.
|
||||
|
||||
eos_token_ids: (`optional`) int or list of int
|
||||
End of sequence token or list of tokens to stop the generation. Default to 0.
|
||||
|
||||
length_penalty: (`optional`) float
|
||||
Exponential penalty to the length. Default to 1.
|
||||
|
||||
@@ -601,7 +602,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||
no_repeat_ngram_size = (
|
||||
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||
@@ -615,8 +616,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
|
||||
else:
|
||||
batch_size = 1
|
||||
if isinstance(eos_token_ids, int):
|
||||
eos_token_ids = [eos_token_ids]
|
||||
|
||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
||||
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
||||
@@ -633,9 +632,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
assert pad_token_id is None or (
|
||||
isinstance(pad_token_id, int) and (pad_token_id >= 0)
|
||||
), "`pad_token_id` should be a positive integer."
|
||||
assert (eos_token_ids is None) or (
|
||||
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
|
||||
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
||||
assert (eos_token_id is None) or (
|
||||
isinstance(eos_token_id, int) and (eos_token_id >= 0)
|
||||
), "`eos_token_id` should be a positive integer."
|
||||
assert (
|
||||
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
|
||||
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
|
||||
@@ -674,11 +673,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
elif attention_mask is None:
|
||||
attention_mask = tf.ones_like(input_ids)
|
||||
|
||||
if pad_token_id is None and eos_token_ids is not None:
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
logger.warning(
|
||||
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
|
||||
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
|
||||
)
|
||||
pad_token_id = eos_token_ids[0]
|
||||
pad_token_id = eos_token_id
|
||||
|
||||
# current position and vocab size
|
||||
cur_len = shape_list(input_ids)[1]
|
||||
@@ -742,7 +741,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
eos_token_id=eos_token_id,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
num_return_sequences=num_return_sequences,
|
||||
@@ -766,7 +765,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
eos_token_id=eos_token_id,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
vocab_size=vocab_size,
|
||||
@@ -790,7 +789,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
no_repeat_ngram_size,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
eos_token_id,
|
||||
decoder_start_token_id,
|
||||
batch_size,
|
||||
vocab_size,
|
||||
@@ -839,10 +838,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
)
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
# create eos_token_ids boolean mask
|
||||
if eos_token_id is not None and cur_len < min_length:
|
||||
# create eos_token_id boolean mask
|
||||
is_token_logit_eos_token = tf.convert_to_tensor(
|
||||
[True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
|
||||
[True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
|
||||
)
|
||||
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
||||
|
||||
@@ -865,28 +864,27 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)
|
||||
|
||||
# update generations and finished sentences
|
||||
if eos_token_ids is not None:
|
||||
# pad finished sentences if eos_token_ids exist
|
||||
if eos_token_id is not None:
|
||||
# pad finished sentences if eos_token_id exist
|
||||
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
|
||||
else:
|
||||
tokens_to_add = next_token
|
||||
|
||||
input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
|
||||
|
||||
if eos_token_ids is not None:
|
||||
for eos_token_id in eos_token_ids:
|
||||
eos_in_sents = tokens_to_add == eos_token_id
|
||||
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
||||
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
|
||||
unfinished_sents, tf.cast(eos_in_sents, tf.int32)
|
||||
)
|
||||
sent_lengths = (
|
||||
sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
|
||||
+ cur_len * is_sents_unfinished_and_token_to_add_is_eos
|
||||
)
|
||||
if eos_token_id is not None:
|
||||
eos_in_sents = tokens_to_add == eos_token_id
|
||||
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
||||
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
|
||||
unfinished_sents, tf.cast(eos_in_sents, tf.int32)
|
||||
)
|
||||
sent_lengths = (
|
||||
sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
|
||||
+ cur_len * is_sents_unfinished_and_token_to_add_is_eos
|
||||
)
|
||||
|
||||
# unfinished_sents is set to zero if eos in sentence
|
||||
unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
|
||||
# unfinished_sents is set to zero if eos in sentence
|
||||
unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
|
||||
|
||||
# stop when there is a </s> in each sentence, or if we exceed the maximul length
|
||||
if tf.math.reduce_max(unfinished_sents) == 0:
|
||||
@@ -937,8 +935,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
no_repeat_ngram_size,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
decoder_start_token_id,
|
||||
eos_token_id,
|
||||
batch_size,
|
||||
num_return_sequences,
|
||||
length_penalty,
|
||||
@@ -996,10 +994,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
# create eos_token_ids boolean mask
|
||||
if eos_token_id is not None and cur_len < min_length:
|
||||
# create eos_token_id boolean mask
|
||||
is_token_logit_eos_token = tf.convert_to_tensor(
|
||||
[True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
|
||||
[True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
|
||||
)
|
||||
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
||||
|
||||
@@ -1072,7 +1070,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
len(generated_hyps[batch_idx]) >= num_beams
|
||||
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
|
||||
assert (
|
||||
eos_token_ids is not None and pad_token_id is not None
|
||||
eos_token_id is not None and pad_token_id is not None
|
||||
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
|
||||
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
||||
continue
|
||||
@@ -1091,7 +1089,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
|
||||
effective_beam_id = batch_idx * num_beams + beam_id
|
||||
# add to generated hypotheses if end of sentence or last iteration
|
||||
if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
|
||||
if eos_token_id is not None and token_id.numpy() is eos_token_id:
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
@@ -1148,8 +1146,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
if done[batch_idx]:
|
||||
continue
|
||||
# test that beam scores match previously calculated scores if not eos and batch_idx not done
|
||||
if eos_token_ids is not None and all(
|
||||
(token_id % vocab_size).numpy().item() not in eos_token_ids for token_id in next_tokens[batch_idx]
|
||||
if eos_token_id is not None and all(
|
||||
(token_id % vocab_size).numpy().item() is not eos_token_id for token_id in next_tokens[batch_idx]
|
||||
):
|
||||
assert tf.reduce_all(
|
||||
next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
|
||||
@@ -1199,7 +1197,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded_hypo = tf.where(
|
||||
tf.range(max_length) == sent_lengths[i],
|
||||
eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
|
||||
eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
|
||||
decoded_hypo,
|
||||
)
|
||||
decoded_list.append(decoded_hypo)
|
||||
|
||||
@@ -665,7 +665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
repetition_penalty=None,
|
||||
bos_token_id=None,
|
||||
pad_token_id=None,
|
||||
eos_token_ids=None,
|
||||
eos_token_id=None,
|
||||
length_penalty=None,
|
||||
no_repeat_ngram_size=None,
|
||||
num_return_sequences=None,
|
||||
@@ -713,6 +713,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
repetition_penalty: (`optional`) float
|
||||
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
||||
|
||||
pad_token_id: (`optional`) int
|
||||
Padding token. Default to specicic model pad_token_id or None if it does not exist.
|
||||
|
||||
bos_token_id: (`optional`) int
|
||||
BOS token. Defaults to bos_token_id as defined in the models config.
|
||||
|
||||
@@ -798,7 +801,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||
no_repeat_ngram_size = (
|
||||
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||
@@ -812,8 +815,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
batch_size = input_ids.shape[0] # overriden by the input batch_size
|
||||
else:
|
||||
batch_size = 1
|
||||
if isinstance(eos_token_ids, int):
|
||||
eos_token_ids = [eos_token_ids]
|
||||
|
||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
||||
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
||||
@@ -830,12 +831,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
assert pad_token_id is None or (
|
||||
isinstance(pad_token_id, int) and (pad_token_id >= 0)
|
||||
), "`pad_token_id` should be a positive integer."
|
||||
assert (eos_token_ids is None) or (
|
||||
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
|
||||
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
||||
assert (
|
||||
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
|
||||
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
|
||||
assert (eos_token_id is None) or (
|
||||
isinstance(eos_token_id, int) and (eos_token_id >= 0)
|
||||
), "`eos_token_id` should be a positive integer."
|
||||
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
||||
assert (
|
||||
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
|
||||
@@ -876,13 +877,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
elif attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||
|
||||
# set pad_token_id to eos_token_ids if not set. Important that this is done after
|
||||
# set pad_token_id to eos_token_id if not set. Important that this is done after
|
||||
# attention_mask is created
|
||||
if pad_token_id is None and eos_token_ids is not None:
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
logger.warning(
|
||||
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
|
||||
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
|
||||
)
|
||||
pad_token_id = eos_token_ids[0]
|
||||
pad_token_id = eos_token_id
|
||||
|
||||
# current position and vocab size
|
||||
vocab_size = self.config.vocab_size
|
||||
@@ -947,8 +948,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
num_return_sequences=num_return_sequences,
|
||||
length_penalty=length_penalty,
|
||||
@@ -971,8 +972,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
bos_token_id=bos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_ids=eos_token_ids,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
batch_size=effective_batch_size,
|
||||
encoder_outputs=encoder_outputs,
|
||||
attention_mask=attention_mask,
|
||||
@@ -994,7 +995,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
no_repeat_ngram_size,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
eos_token_id,
|
||||
decoder_start_token_id,
|
||||
batch_size,
|
||||
encoder_outputs,
|
||||
@@ -1031,9 +1032,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
for eos_token_id in eos_token_ids:
|
||||
next_token_logits[:, eos_token_id] = -float("inf")
|
||||
if eos_token_id is not None and cur_len < min_length:
|
||||
next_token_logits[:, eos_token_id] = -float("inf")
|
||||
|
||||
if do_sample:
|
||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
||||
@@ -1049,22 +1049,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
next_token = torch.argmax(next_token_logits, dim=-1)
|
||||
|
||||
# update generations and finished sentences
|
||||
if eos_token_ids is not None:
|
||||
# pad finished sentences if eos_token_ids exist
|
||||
if eos_token_id is not None:
|
||||
# pad finished sentences if eos_token_id exist
|
||||
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
|
||||
else:
|
||||
tokens_to_add = next_token
|
||||
|
||||
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
|
||||
|
||||
if eos_token_ids is not None:
|
||||
for eos_token_id in eos_token_ids:
|
||||
eos_in_sents = tokens_to_add == eos_token_id
|
||||
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
||||
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
|
||||
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
|
||||
# unfinished_sents is set to zero if eos in sentence
|
||||
unfinished_sents.mul_((~eos_in_sents).long())
|
||||
if eos_token_id is not None:
|
||||
eos_in_sents = tokens_to_add == eos_token_id
|
||||
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
||||
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
|
||||
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
|
||||
# unfinished_sents is set to zero if eos in sentence
|
||||
unfinished_sents.mul_((~eos_in_sents).long())
|
||||
|
||||
# stop when there is a </s> in each sentence, or if we exceed the maximul length
|
||||
if unfinished_sents.max() == 0:
|
||||
@@ -1106,7 +1105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
no_repeat_ngram_size,
|
||||
bos_token_id,
|
||||
pad_token_id,
|
||||
eos_token_ids,
|
||||
eos_token_id,
|
||||
decoder_start_token_id,
|
||||
batch_size,
|
||||
num_return_sequences,
|
||||
@@ -1163,9 +1162,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_ids is not None and cur_len < min_length:
|
||||
for eos_token_id in eos_token_ids:
|
||||
scores[:, eos_token_id] = -float("inf")
|
||||
if eos_token_id is not None and cur_len < min_length:
|
||||
scores[:, eos_token_id] = -float("inf")
|
||||
|
||||
if no_repeat_ngram_size > 0:
|
||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
@@ -1225,7 +1223,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
len(generated_hyps[batch_idx]) >= num_beams
|
||||
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
|
||||
assert (
|
||||
eos_token_ids is not None and pad_token_id is not None
|
||||
eos_token_id is not None and pad_token_id is not None
|
||||
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
|
||||
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
||||
continue
|
||||
@@ -1244,7 +1242,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
effective_beam_id = batch_idx * num_beams + beam_id
|
||||
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
|
||||
if (eos_token_id is not None) and (token_id.item() is eos_token_id):
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
@@ -1303,8 +1301,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
continue
|
||||
|
||||
# test that beam scores match previously calculated scores if not eos and batch_idx not done
|
||||
if eos_token_ids is not None and all(
|
||||
(token_id % vocab_size).item() not in eos_token_ids for token_id in next_tokens[batch_idx]
|
||||
if eos_token_id is not None and all(
|
||||
(token_id % vocab_size).item() is not eos_token_id for token_id in next_tokens[batch_idx]
|
||||
):
|
||||
assert torch.all(
|
||||
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
|
||||
@@ -1346,7 +1344,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, : sent_lengths[i]] = hypo
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded[i, sent_lengths[i]] = eos_token_ids[0]
|
||||
decoded[i, sent_lengths[i]] = eos_token_id
|
||||
else:
|
||||
# none of the hypotheses have an eos_token
|
||||
assert (len(hypo) == max_length for hypo in best)
|
||||
|
||||
@@ -61,7 +61,7 @@ class ModelTester:
|
||||
self.hidden_dropout_prob = 0.1
|
||||
self.attention_probs_dropout_prob = 0.1
|
||||
self.max_position_embeddings = 20
|
||||
self.eos_token_ids = [2]
|
||||
self.eos_token_id = 2
|
||||
self.pad_token_id = 1
|
||||
self.bos_token_id = 0
|
||||
torch.manual_seed(0)
|
||||
@@ -82,7 +82,7 @@ class ModelTester:
|
||||
dropout=self.hidden_dropout_prob,
|
||||
attention_dropout=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
eos_token_ids=self.eos_token_ids,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
@@ -214,7 +214,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
output_past=output_past,
|
||||
eos_token_ids=[2],
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
)
|
||||
@@ -274,7 +274,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
output_past=True,
|
||||
eos_token_ids=[2],
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
)
|
||||
@@ -483,7 +483,7 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
no_repeat_ngram_size=3,
|
||||
do_sample=False,
|
||||
early_stopping=True,
|
||||
decoder_start_token_id=hf.config.eos_token_ids[0],
|
||||
decoder_start_token_id=hf.config.eos_token_id,
|
||||
)
|
||||
|
||||
decoded = [
|
||||
|
||||
@@ -132,7 +132,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
# type_vocab_size=self.type_vocab_size,
|
||||
# initializer_range=self.initializer_range
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_ids=self.eos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
||||
@@ -130,7 +130,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
# type_vocab_size=self.type_vocab_size,
|
||||
# initializer_range=self.initializer_range
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_ids=self.eos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
||||
@@ -107,7 +107,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
d_inner=self.d_inner,
|
||||
div_val=self.div_val,
|
||||
n_layer=self.num_hidden_layers,
|
||||
eos_token_ids=self.eos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, lm_labels)
|
||||
|
||||
@@ -103,7 +103,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
d_inner=self.d_inner,
|
||||
div_val=self.div_val,
|
||||
n_layer=self.num_hidden_layers,
|
||||
eos_token_ids=self.eos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
)
|
||||
|
||||
return (config, input_ids_1, input_ids_2, lm_labels)
|
||||
|
||||
Reference in New Issue
Block a user