Clean special token init in modeling_....py (#3264)

* make style

* fix conflicts
This commit is contained in:
Patrick von Platen
2020-03-20 21:41:04 +01:00
committed by GitHub
parent 8becb73293
commit 95e00d0808
22 changed files with 117 additions and 115 deletions

View File

@@ -35,7 +35,7 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
min_length=min_length + 1, # +1 from original because we start at step=1 min_length=min_length + 1, # +1 from original because we start at step=1
no_repeat_ngram_size=3, no_repeat_ngram_size=3,
early_stopping=True, early_stopping=True,
decoder_start_token_id=model.config.eos_token_ids[0], decoder_start_token_id=model.config.eos_token_id,
) )
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
for hypothesis in dec: for hypothesis in dec:

View File

@@ -223,6 +223,7 @@ if is_torch_available():
BartForSequenceClassification, BartForSequenceClassification,
BartModel, BartModel,
BartForConditionalGeneration, BartForConditionalGeneration,
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
) )
from .modeling_roberta import ( from .modeling_roberta import (
RobertaForMaskedLM, RobertaForMaskedLM,

View File

@@ -124,9 +124,12 @@ class AlbertConfig(PretrainedConfig):
initializer_range=0.02, initializer_range=0.02,
layer_norm_eps=1e-12, layer_norm_eps=1e-12,
classifier_dropout_prob=0.1, classifier_dropout_prob=0.1,
pad_token_id=0,
bos_token_id=2,
eos_token_id=3,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.embedding_size = embedding_size self.embedding_size = embedding_size

View File

@@ -41,9 +41,6 @@ class BartConfig(PretrainedConfig):
activation_dropout=0.0, activation_dropout=0.0,
activation_function="gelu", activation_function="gelu",
vocab_size=50265, vocab_size=50265,
bos_token_id=0,
pad_token_id=1,
eos_token_ids=[2],
d_model=1024, d_model=1024,
encoder_ffn_dim=4096, encoder_ffn_dim=4096,
encoder_layers=12, encoder_layers=12,
@@ -61,6 +58,9 @@ class BartConfig(PretrainedConfig):
output_past=False, output_past=False,
num_labels=3, num_labels=3,
is_encoder_decoder=True, is_encoder_decoder=True,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
**common_kwargs **common_kwargs
): ):
r""" r"""
@@ -74,7 +74,7 @@ class BartConfig(PretrainedConfig):
output_past=output_past, output_past=output_past,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
bos_token_id=bos_token_id, bos_token_id=bos_token_id,
eos_token_ids=eos_token_ids, eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder, is_encoder_decoder=is_encoder_decoder,
**common_kwargs, **common_kwargs,
) )

View File

@@ -124,9 +124,10 @@ class BertConfig(PretrainedConfig):
type_vocab_size=2, type_vocab_size=2,
initializer_range=0.02, initializer_range=0.02,
layer_norm_eps=1e-12, layer_norm_eps=1e-12,
pad_token_id=0,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.hidden_size = hidden_size self.hidden_size = hidden_size

View File

@@ -113,9 +113,10 @@ class DistilBertConfig(PretrainedConfig):
initializer_range=0.02, initializer_range=0.02,
qa_dropout=0.1, qa_dropout=0.1,
seq_classif_dropout=0.2, seq_classif_dropout=0.2,
pad_token_id=0,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(**kwargs, pad_token_id=pad_token_id)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.sinusoidal_pos_embds = sinusoidal_pos_embds self.sinusoidal_pos_embds = sinusoidal_pos_embds

View File

@@ -145,9 +145,9 @@ class FlaubertConfig(XLMConfig):
pretrained_config_archive_map = FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP pretrained_config_archive_map = FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "flaubert" model_type = "flaubert"
def __init__(self, layerdrop=0.0, pre_norm=False, **kwargs): def __init__(self, layerdrop=0.0, pre_norm=False, pad_token_id=2, bos_token_id=0, **kwargs):
"""Constructs FlaubertConfig. """Constructs FlaubertConfig.
""" """
super().__init__(**kwargs) super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
self.layerdrop = layerdrop self.layerdrop = layerdrop
self.pre_norm = pre_norm self.pre_norm = pre_norm

View File

@@ -142,7 +142,7 @@ class GPT2Config(PretrainedConfig):
eos_token_id=50256, eos_token_id=50256,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.n_ctx = n_ctx self.n_ctx = n_ctx
@@ -163,7 +163,7 @@ class GPT2Config(PretrainedConfig):
self.summary_proj_to_labels = summary_proj_to_labels self.summary_proj_to_labels = summary_proj_to_labels
self.bos_token_id = bos_token_id self.bos_token_id = bos_token_id
self.eos_token_ids = [eos_token_id] self.eos_token_id = eos_token_id
@property @property
def max_position_embeddings(self): def max_position_embeddings(self):

View File

@@ -66,3 +66,8 @@ class RobertaConfig(BertConfig):
""" """
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "roberta" model_type = "roberta"
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
"""Constructs FlaubertConfig.
"""
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

View File

@@ -77,11 +77,11 @@ class T5Config(PretrainedConfig):
initializer_factor=1.0, initializer_factor=1.0,
is_encoder_decoder=True, is_encoder_decoder=True,
pad_token_id=0, pad_token_id=0,
eos_token_ids=[1], eos_token_id=1,
**kwargs **kwargs
): ):
super().__init__( super().__init__(
is_encoder_decoder=is_encoder_decoder, **kwargs, pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, **kwargs,
) )
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.n_positions = n_positions self.n_positions = n_positions

View File

@@ -152,7 +152,7 @@ class TransfoXLConfig(PretrainedConfig):
eos_token_id=0, eos_token_id=0,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.cutoffs = [] self.cutoffs = []
@@ -187,8 +187,6 @@ class TransfoXLConfig(PretrainedConfig):
self.init_std = init_std self.init_std = init_std
self.layer_norm_epsilon = layer_norm_epsilon self.layer_norm_epsilon = layer_norm_epsilon
self.eos_token_ids = [eos_token_id]
@property @property
def max_position_embeddings(self): def max_position_embeddings(self):
return self.tgt_len + self.ext_len + self.mem_len return self.tgt_len + self.ext_len + self.mem_len

View File

@@ -80,7 +80,7 @@ class PretrainedConfig(object):
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
self.bos_token_id = kwargs.pop("bos_token_id", None) self.bos_token_id = kwargs.pop("bos_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None) self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_ids = kwargs.pop("eos_token_ids", None) self.eos_token_id = kwargs.pop("eos_token_id", None)
self.length_penalty = kwargs.pop("length_penalty", 1.0) self.length_penalty = kwargs.pop("length_penalty", 1.0)
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
self.num_return_sequences = kwargs.pop("num_return_sequences", 1) self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

View File

@@ -194,13 +194,13 @@ class XLMConfig(PretrainedConfig):
end_n_top=5, end_n_top=5,
mask_token_id=0, mask_token_id=0,
lang_id=0, lang_id=0,
bos_token_id=0,
pad_token_id=2, pad_token_id=2,
bos_token_id=0,
**kwargs **kwargs
): ):
"""Constructs XLMConfig. """Constructs XLMConfig.
""" """
super().__init__(**kwargs) super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.emb_dim = emb_dim self.emb_dim = emb_dim
self.n_layers = n_layers self.n_layers = n_layers
@@ -236,9 +236,6 @@ class XLMConfig(PretrainedConfig):
if "n_words" in kwargs: if "n_words" in kwargs:
self.n_words = kwargs["n_words"] self.n_words = kwargs["n_words"]
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
@property @property
def n_words(self): # For backward compatibility def n_words(self): # For backward compatibility
return self.vocab_size return self.vocab_size

View File

@@ -155,14 +155,14 @@ class XLNetConfig(PretrainedConfig):
summary_last_dropout=0.1, summary_last_dropout=0.1,
start_n_top=5, start_n_top=5,
end_n_top=5, end_n_top=5,
bos_token_id=1,
pad_token_id=5, pad_token_id=5,
bos_token_id=1,
eos_token_id=2, eos_token_id=2,
**kwargs **kwargs
): ):
"""Constructs XLNetConfig. """Constructs XLNetConfig.
""" """
super().__init__(**kwargs) super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.d_model = d_model self.d_model = d_model
self.n_layer = n_layer self.n_layer = n_layer
@@ -193,7 +193,7 @@ class XLNetConfig(PretrainedConfig):
self.bos_token_id = bos_token_id self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id self.pad_token_id = pad_token_id
self.eos_token_ids = [eos_token_id] self.eos_token_id = eos_token_id
@property @property
def max_position_embeddings(self): def max_position_embeddings(self):

View File

@@ -906,8 +906,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
def prepare_scores_for_generation(self, scores, cur_len, max_length): def prepare_scores_for_generation(self, scores, cur_len, max_length):
if cur_len == 1: if cur_len == 1:
self._force_token_ids_generation(scores, self.config.bos_token_id) self._force_token_ids_generation(scores, self.config.bos_token_id)
if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None: if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(scores, self.config.eos_token_ids[0]) self._force_token_ids_generation(scores, self.config.eos_token_id)
return scores return scores
@staticmethod @staticmethod
@@ -1003,7 +1003,7 @@ class BartForSequenceClassification(PretrainedBartModel):
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
) )
x = outputs[0] # last hidden state x = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_ids[0]) eos_mask = input_ids.eq(self.config.eos_token_id)
if len(torch.unique(eos_mask.sum(1))) > 1: if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.") raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]

View File

@@ -469,7 +469,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
repetition_penalty=None, repetition_penalty=None,
bos_token_id=None, bos_token_id=None,
pad_token_id=None, pad_token_id=None,
eos_token_ids=None, eos_token_id=None,
length_penalty=None, length_penalty=None,
no_repeat_ngram_size=None, no_repeat_ngram_size=None,
num_return_sequences=None, num_return_sequences=None,
@@ -518,13 +518,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0. The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
bos_token_id: (`optional`) int bos_token_id: (`optional`) int
Beginning of sentence token if no prompt is provided. Default to 0. Beginning of sentence token if no prompt is provided. Default to specicic model bos_token_id or None if it does not exist.
pad_token_id: (`optional`) int pad_token_id: (`optional`) int
Pad token. Defaults to pad_token_id as defined in the models config. Pad token. Defaults to pad_token_id as defined in the models config.
eos_token_ids: (`optional`) int or list of int eos_token_ids: (`optional`) int or list of int
End of sequence token or list of tokens to stop the generation. Default to 0. End of sequence token or list of tokens to stop the generation. Default to 0.
length_penalty: (`optional`) float length_penalty: (`optional`) float
Exponential penalty to the length. Default to 1. Exponential penalty to the length. Default to 1.
@@ -601,7 +602,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = ( no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
@@ -615,8 +616,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
else: else:
batch_size = 1 batch_size = 1
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer." assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
@@ -633,9 +632,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
assert pad_token_id is None or ( assert pad_token_id is None or (
isinstance(pad_token_id, int) and (pad_token_id >= 0) isinstance(pad_token_id, int) and (pad_token_id >= 0)
), "`pad_token_id` should be a positive integer." ), "`pad_token_id` should be a positive integer."
assert (eos_token_ids is None) or ( assert (eos_token_id is None) or (
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids) isinstance(eos_token_id, int) and (eos_token_id >= 0)
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers." ), "`eos_token_id` should be a positive integer."
assert ( assert (
decoder_start_token_id is not None or self.config.is_encoder_decoder is False decoder_start_token_id is not None or self.config.is_encoder_decoder is False
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model" ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
@@ -674,11 +673,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
elif attention_mask is None: elif attention_mask is None:
attention_mask = tf.ones_like(input_ids) attention_mask = tf.ones_like(input_ids)
if pad_token_id is None and eos_token_ids is not None: if pad_token_id is None and eos_token_id is not None:
logger.warning( logger.warning(
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0]) "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
) )
pad_token_id = eos_token_ids[0] pad_token_id = eos_token_id
# current position and vocab size # current position and vocab size
cur_len = shape_list(input_ids)[1] cur_len = shape_list(input_ids)[1]
@@ -742,7 +741,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
no_repeat_ngram_size=no_repeat_ngram_size, no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id, bos_token_id=bos_token_id,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids, eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id, decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size, batch_size=effective_batch_size,
num_return_sequences=num_return_sequences, num_return_sequences=num_return_sequences,
@@ -766,7 +765,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
no_repeat_ngram_size=no_repeat_ngram_size, no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id, bos_token_id=bos_token_id,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids, eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id, decoder_start_token_id=decoder_start_token_id,
batch_size=effective_batch_size, batch_size=effective_batch_size,
vocab_size=vocab_size, vocab_size=vocab_size,
@@ -790,7 +789,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
no_repeat_ngram_size, no_repeat_ngram_size,
bos_token_id, bos_token_id,
pad_token_id, pad_token_id,
eos_token_ids, eos_token_id,
decoder_start_token_id, decoder_start_token_id,
batch_size, batch_size,
vocab_size, vocab_size,
@@ -839,10 +838,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
) )
# set eos token prob to zero if min_length is not reached # set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length: if eos_token_id is not None and cur_len < min_length:
# create eos_token_ids boolean mask # create eos_token_id boolean mask
is_token_logit_eos_token = tf.convert_to_tensor( is_token_logit_eos_token = tf.convert_to_tensor(
[True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool [True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
) )
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size]) eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
@@ -865,16 +864,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32) next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)
# update generations and finished sentences # update generations and finished sentences
if eos_token_ids is not None: if eos_token_id is not None:
# pad finished sentences if eos_token_ids exist # pad finished sentences if eos_token_id exist
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
else: else:
tokens_to_add = next_token tokens_to_add = next_token
input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1) input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
if eos_token_ids is not None: if eos_token_id is not None:
for eos_token_id in eos_token_ids:
eos_in_sents = tokens_to_add == eos_token_id eos_in_sents = tokens_to_add == eos_token_id
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply( is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
@@ -937,8 +935,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
no_repeat_ngram_size, no_repeat_ngram_size,
bos_token_id, bos_token_id,
pad_token_id, pad_token_id,
eos_token_ids,
decoder_start_token_id, decoder_start_token_id,
eos_token_id,
batch_size, batch_size,
num_return_sequences, num_return_sequences,
length_penalty, length_penalty,
@@ -996,10 +994,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
# set eos token prob to zero if min_length is not reached # set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length: if eos_token_id is not None and cur_len < min_length:
# create eos_token_ids boolean mask # create eos_token_id boolean mask
is_token_logit_eos_token = tf.convert_to_tensor( is_token_logit_eos_token = tf.convert_to_tensor(
[True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool [True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
) )
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size]) eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
@@ -1072,7 +1070,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
len(generated_hyps[batch_idx]) >= num_beams len(generated_hyps[batch_idx]) >= num_beams
), "Batch can only be done if at least {} beams have been generated".format(num_beams) ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
assert ( assert (
eos_token_ids is not None and pad_token_id is not None eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue continue
@@ -1091,7 +1089,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
effective_beam_id = batch_idx * num_beams + beam_id effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence or last iteration # add to generated hypotheses if end of sentence or last iteration
if eos_token_ids is not None and token_id.numpy() in eos_token_ids: if eos_token_id is not None and token_id.numpy() is eos_token_id:
# if beam_token does not belong to top num_beams tokens, it should not be added # if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
if is_beam_token_worse_than_top_num_beams: if is_beam_token_worse_than_top_num_beams:
@@ -1148,8 +1146,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if done[batch_idx]: if done[batch_idx]:
continue continue
# test that beam scores match previously calculated scores if not eos and batch_idx not done # test that beam scores match previously calculated scores if not eos and batch_idx not done
if eos_token_ids is not None and all( if eos_token_id is not None and all(
(token_id % vocab_size).numpy().item() not in eos_token_ids for token_id in next_tokens[batch_idx] (token_id % vocab_size).numpy().item() is not eos_token_id for token_id in next_tokens[batch_idx]
): ):
assert tf.reduce_all( assert tf.reduce_all(
next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx] next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
@@ -1199,7 +1197,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if sent_lengths[i] < max_length: if sent_lengths[i] < max_length:
decoded_hypo = tf.where( decoded_hypo = tf.where(
tf.range(max_length) == sent_lengths[i], tf.range(max_length) == sent_lengths[i],
eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32), eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
decoded_hypo, decoded_hypo,
) )
decoded_list.append(decoded_hypo) decoded_list.append(decoded_hypo)

View File

@@ -665,7 +665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
repetition_penalty=None, repetition_penalty=None,
bos_token_id=None, bos_token_id=None,
pad_token_id=None, pad_token_id=None,
eos_token_ids=None, eos_token_id=None,
length_penalty=None, length_penalty=None,
no_repeat_ngram_size=None, no_repeat_ngram_size=None,
num_return_sequences=None, num_return_sequences=None,
@@ -713,6 +713,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
repetition_penalty: (`optional`) float repetition_penalty: (`optional`) float
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0. The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
pad_token_id: (`optional`) int
Padding token. Default to specicic model pad_token_id or None if it does not exist.
bos_token_id: (`optional`) int bos_token_id: (`optional`) int
BOS token. Defaults to bos_token_id as defined in the models config. BOS token. Defaults to bos_token_id as defined in the models config.
@@ -798,7 +801,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = ( no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
@@ -812,8 +815,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
batch_size = input_ids.shape[0] # overriden by the input batch_size batch_size = input_ids.shape[0] # overriden by the input batch_size
else: else:
batch_size = 1 batch_size = 1
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
@@ -830,12 +831,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert pad_token_id is None or ( assert pad_token_id is None or (
isinstance(pad_token_id, int) and (pad_token_id >= 0) isinstance(pad_token_id, int) and (pad_token_id >= 0)
), "`pad_token_id` should be a positive integer." ), "`pad_token_id` should be a positive integer."
assert (eos_token_ids is None) or (
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert ( assert (
decoder_start_token_id is not None or self.config.is_encoder_decoder is False decoder_start_token_id is not None or self.config.is_encoder_decoder is False
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model" ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
assert (eos_token_id is None) or (
isinstance(eos_token_id, int) and (eos_token_id >= 0)
), "`eos_token_id` should be a positive integer."
assert length_penalty > 0, "`length_penalty` should be strictly positive." assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert ( assert (
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0 isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
@@ -876,13 +877,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
elif attention_mask is None: elif attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape) attention_mask = input_ids.new_ones(input_ids.shape)
# set pad_token_id to eos_token_ids if not set. Important that this is done after # set pad_token_id to eos_token_id if not set. Important that this is done after
# attention_mask is created # attention_mask is created
if pad_token_id is None and eos_token_ids is not None: if pad_token_id is None and eos_token_id is not None:
logger.warning( logger.warning(
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0]) "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
) )
pad_token_id = eos_token_ids[0] pad_token_id = eos_token_id
# current position and vocab size # current position and vocab size
vocab_size = self.config.vocab_size vocab_size = self.config.vocab_size
@@ -947,8 +948,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
no_repeat_ngram_size=no_repeat_ngram_size, no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id, bos_token_id=bos_token_id,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
decoder_start_token_id=decoder_start_token_id, decoder_start_token_id=decoder_start_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size, batch_size=effective_batch_size,
num_return_sequences=num_return_sequences, num_return_sequences=num_return_sequences,
length_penalty=length_penalty, length_penalty=length_penalty,
@@ -971,8 +972,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
no_repeat_ngram_size=no_repeat_ngram_size, no_repeat_ngram_size=no_repeat_ngram_size,
bos_token_id=bos_token_id, bos_token_id=bos_token_id,
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
eos_token_ids=eos_token_ids,
decoder_start_token_id=decoder_start_token_id, decoder_start_token_id=decoder_start_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size, batch_size=effective_batch_size,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
attention_mask=attention_mask, attention_mask=attention_mask,
@@ -994,7 +995,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
no_repeat_ngram_size, no_repeat_ngram_size,
bos_token_id, bos_token_id,
pad_token_id, pad_token_id,
eos_token_ids, eos_token_id,
decoder_start_token_id, decoder_start_token_id,
batch_size, batch_size,
encoder_outputs, encoder_outputs,
@@ -1031,8 +1032,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf") next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
# set eos token prob to zero if min_length is not reached # set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length: if eos_token_id is not None and cur_len < min_length:
for eos_token_id in eos_token_ids:
next_token_logits[:, eos_token_id] = -float("inf") next_token_logits[:, eos_token_id] = -float("inf")
if do_sample: if do_sample:
@@ -1049,16 +1049,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_token = torch.argmax(next_token_logits, dim=-1) next_token = torch.argmax(next_token_logits, dim=-1)
# update generations and finished sentences # update generations and finished sentences
if eos_token_ids is not None: if eos_token_id is not None:
# pad finished sentences if eos_token_ids exist # pad finished sentences if eos_token_id exist
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
else: else:
tokens_to_add = next_token tokens_to_add = next_token
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
if eos_token_ids is not None: if eos_token_id is not None:
for eos_token_id in eos_token_ids:
eos_in_sents = tokens_to_add == eos_token_id eos_in_sents = tokens_to_add == eos_token_id
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool() is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
@@ -1106,7 +1105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
no_repeat_ngram_size, no_repeat_ngram_size,
bos_token_id, bos_token_id,
pad_token_id, pad_token_id,
eos_token_ids, eos_token_id,
decoder_start_token_id, decoder_start_token_id,
batch_size, batch_size,
num_return_sequences, num_return_sequences,
@@ -1163,8 +1162,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length) scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
# set eos token prob to zero if min_length is not reached # set eos token prob to zero if min_length is not reached
if eos_token_ids is not None and cur_len < min_length: if eos_token_id is not None and cur_len < min_length:
for eos_token_id in eos_token_ids:
scores[:, eos_token_id] = -float("inf") scores[:, eos_token_id] = -float("inf")
if no_repeat_ngram_size > 0: if no_repeat_ngram_size > 0:
@@ -1225,7 +1223,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
len(generated_hyps[batch_idx]) >= num_beams len(generated_hyps[batch_idx]) >= num_beams
), "Batch can only be done if at least {} beams have been generated".format(num_beams) ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
assert ( assert (
eos_token_ids is not None and pad_token_id is not None eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue continue
@@ -1244,7 +1242,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
effective_beam_id = batch_idx * num_beams + beam_id effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence # add to generated hypotheses if end of sentence
if (eos_token_ids is not None) and (token_id.item() in eos_token_ids): if (eos_token_id is not None) and (token_id.item() is eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added # if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
if is_beam_token_worse_than_top_num_beams: if is_beam_token_worse_than_top_num_beams:
@@ -1303,8 +1301,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
continue continue
# test that beam scores match previously calculated scores if not eos and batch_idx not done # test that beam scores match previously calculated scores if not eos and batch_idx not done
if eos_token_ids is not None and all( if eos_token_id is not None and all(
(token_id % vocab_size).item() not in eos_token_ids for token_id in next_tokens[batch_idx] (token_id % vocab_size).item() is not eos_token_id for token_id in next_tokens[batch_idx]
): ):
assert torch.all( assert torch.all(
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx] next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
@@ -1346,7 +1344,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
for i, hypo in enumerate(best): for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < max_length: if sent_lengths[i] < max_length:
decoded[i, sent_lengths[i]] = eos_token_ids[0] decoded[i, sent_lengths[i]] = eos_token_id
else: else:
# none of the hypotheses have an eos_token # none of the hypotheses have an eos_token
assert (len(hypo) == max_length for hypo in best) assert (len(hypo) == max_length for hypo in best)

View File

@@ -61,7 +61,7 @@ class ModelTester:
self.hidden_dropout_prob = 0.1 self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1 self.attention_probs_dropout_prob = 0.1
self.max_position_embeddings = 20 self.max_position_embeddings = 20
self.eos_token_ids = [2] self.eos_token_id = 2
self.pad_token_id = 1 self.pad_token_id = 1
self.bos_token_id = 0 self.bos_token_id = 0
torch.manual_seed(0) torch.manual_seed(0)
@@ -82,7 +82,7 @@ class ModelTester:
dropout=self.hidden_dropout_prob, dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob, attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
eos_token_ids=self.eos_token_ids, eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id, bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id, pad_token_id=self.pad_token_id,
) )
@@ -214,7 +214,7 @@ class BartHeadTests(unittest.TestCase):
decoder_ffn_dim=32, decoder_ffn_dim=32,
max_position_embeddings=48, max_position_embeddings=48,
output_past=output_past, output_past=output_past,
eos_token_ids=[2], eos_token_id=2,
pad_token_id=1, pad_token_id=1,
bos_token_id=0, bos_token_id=0,
) )
@@ -274,7 +274,7 @@ class BartHeadTests(unittest.TestCase):
decoder_ffn_dim=32, decoder_ffn_dim=32,
max_position_embeddings=48, max_position_embeddings=48,
output_past=True, output_past=True,
eos_token_ids=[2], eos_token_id=2,
pad_token_id=1, pad_token_id=1,
bos_token_id=0, bos_token_id=0,
) )
@@ -483,7 +483,7 @@ class BartModelIntegrationTests(unittest.TestCase):
no_repeat_ngram_size=3, no_repeat_ngram_size=3,
do_sample=False, do_sample=False,
early_stopping=True, early_stopping=True,
decoder_start_token_id=hf.config.eos_token_ids[0], decoder_start_token_id=hf.config.eos_token_id,
) )
decoded = [ decoded = [

View File

@@ -132,7 +132,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
# type_vocab_size=self.type_vocab_size, # type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range # initializer_range=self.initializer_range
bos_token_id=self.bos_token_id, bos_token_id=self.bos_token_id,
eos_token_ids=self.eos_token_id, eos_token_id=self.eos_token_id,
) )
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

View File

@@ -130,7 +130,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
# type_vocab_size=self.type_vocab_size, # type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range # initializer_range=self.initializer_range
bos_token_id=self.bos_token_id, bos_token_id=self.bos_token_id,
eos_token_ids=self.eos_token_id, eos_token_id=self.eos_token_id,
) )
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

View File

@@ -107,7 +107,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
d_inner=self.d_inner, d_inner=self.d_inner,
div_val=self.div_val, div_val=self.div_val,
n_layer=self.num_hidden_layers, n_layer=self.num_hidden_layers,
eos_token_ids=self.eos_token_id, eos_token_id=self.eos_token_id,
) )
return (config, input_ids_1, input_ids_2, lm_labels) return (config, input_ids_1, input_ids_2, lm_labels)

View File

@@ -103,7 +103,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
d_inner=self.d_inner, d_inner=self.d_inner,
div_val=self.div_val, div_val=self.div_val,
n_layer=self.num_hidden_layers, n_layer=self.num_hidden_layers,
eos_token_ids=self.eos_token_id, eos_token_id=self.eos_token_id,
) )
return (config, input_ids_1, input_ids_2, lm_labels) return (config, input_ids_1, input_ids_2, lm_labels)