Clean special token init in modeling_....py (#3264)
* make style * fix conflicts
This commit is contained in:
committed by
GitHub
parent
8becb73293
commit
95e00d0808
@@ -35,7 +35,7 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
|
|||||||
min_length=min_length + 1, # +1 from original because we start at step=1
|
min_length=min_length + 1, # +1 from original because we start at step=1
|
||||||
no_repeat_ngram_size=3,
|
no_repeat_ngram_size=3,
|
||||||
early_stopping=True,
|
early_stopping=True,
|
||||||
decoder_start_token_id=model.config.eos_token_ids[0],
|
decoder_start_token_id=model.config.eos_token_id,
|
||||||
)
|
)
|
||||||
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
|
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
|
||||||
for hypothesis in dec:
|
for hypothesis in dec:
|
||||||
|
|||||||
@@ -223,6 +223,7 @@ if is_torch_available():
|
|||||||
BartForSequenceClassification,
|
BartForSequenceClassification,
|
||||||
BartModel,
|
BartModel,
|
||||||
BartForConditionalGeneration,
|
BartForConditionalGeneration,
|
||||||
|
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
)
|
)
|
||||||
from .modeling_roberta import (
|
from .modeling_roberta import (
|
||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
|
|||||||
@@ -124,9 +124,12 @@ class AlbertConfig(PretrainedConfig):
|
|||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
layer_norm_eps=1e-12,
|
layer_norm_eps=1e-12,
|
||||||
classifier_dropout_prob=0.1,
|
classifier_dropout_prob=0.1,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=2,
|
||||||
|
eos_token_id=3,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.embedding_size = embedding_size
|
self.embedding_size = embedding_size
|
||||||
|
|||||||
@@ -41,9 +41,6 @@ class BartConfig(PretrainedConfig):
|
|||||||
activation_dropout=0.0,
|
activation_dropout=0.0,
|
||||||
activation_function="gelu",
|
activation_function="gelu",
|
||||||
vocab_size=50265,
|
vocab_size=50265,
|
||||||
bos_token_id=0,
|
|
||||||
pad_token_id=1,
|
|
||||||
eos_token_ids=[2],
|
|
||||||
d_model=1024,
|
d_model=1024,
|
||||||
encoder_ffn_dim=4096,
|
encoder_ffn_dim=4096,
|
||||||
encoder_layers=12,
|
encoder_layers=12,
|
||||||
@@ -61,6 +58,9 @@ class BartConfig(PretrainedConfig):
|
|||||||
output_past=False,
|
output_past=False,
|
||||||
num_labels=3,
|
num_labels=3,
|
||||||
is_encoder_decoder=True,
|
is_encoder_decoder=True,
|
||||||
|
pad_token_id=1,
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=2,
|
||||||
**common_kwargs
|
**common_kwargs
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -74,7 +74,7 @@ class BartConfig(PretrainedConfig):
|
|||||||
output_past=output_past,
|
output_past=output_past,
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
bos_token_id=bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
eos_token_ids=eos_token_ids,
|
eos_token_id=eos_token_id,
|
||||||
is_encoder_decoder=is_encoder_decoder,
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
**common_kwargs,
|
**common_kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -124,9 +124,10 @@ class BertConfig(PretrainedConfig):
|
|||||||
type_vocab_size=2,
|
type_vocab_size=2,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
layer_norm_eps=1e-12,
|
layer_norm_eps=1e-12,
|
||||||
|
pad_token_id=0,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||||
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
|
|||||||
@@ -113,9 +113,10 @@ class DistilBertConfig(PretrainedConfig):
|
|||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
qa_dropout=0.1,
|
qa_dropout=0.1,
|
||||||
seq_classif_dropout=0.2,
|
seq_classif_dropout=0.2,
|
||||||
|
pad_token_id=0,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs, pad_token_id=pad_token_id)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
self.sinusoidal_pos_embds = sinusoidal_pos_embds
|
self.sinusoidal_pos_embds = sinusoidal_pos_embds
|
||||||
|
|||||||
@@ -145,9 +145,9 @@ class FlaubertConfig(XLMConfig):
|
|||||||
pretrained_config_archive_map = FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
model_type = "flaubert"
|
model_type = "flaubert"
|
||||||
|
|
||||||
def __init__(self, layerdrop=0.0, pre_norm=False, **kwargs):
|
def __init__(self, layerdrop=0.0, pre_norm=False, pad_token_id=2, bos_token_id=0, **kwargs):
|
||||||
"""Constructs FlaubertConfig.
|
"""Constructs FlaubertConfig.
|
||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
|
||||||
self.layerdrop = layerdrop
|
self.layerdrop = layerdrop
|
||||||
self.pre_norm = pre_norm
|
self.pre_norm = pre_norm
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ class GPT2Config(PretrainedConfig):
|
|||||||
eos_token_id=50256,
|
eos_token_id=50256,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.n_ctx = n_ctx
|
self.n_ctx = n_ctx
|
||||||
@@ -163,7 +163,7 @@ class GPT2Config(PretrainedConfig):
|
|||||||
self.summary_proj_to_labels = summary_proj_to_labels
|
self.summary_proj_to_labels = summary_proj_to_labels
|
||||||
|
|
||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
self.eos_token_ids = [eos_token_id]
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_position_embeddings(self):
|
def max_position_embeddings(self):
|
||||||
|
|||||||
@@ -66,3 +66,8 @@ class RobertaConfig(BertConfig):
|
|||||||
"""
|
"""
|
||||||
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
model_type = "roberta"
|
model_type = "roberta"
|
||||||
|
|
||||||
|
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
|
||||||
|
"""Constructs FlaubertConfig.
|
||||||
|
"""
|
||||||
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
|
|||||||
@@ -77,11 +77,11 @@ class T5Config(PretrainedConfig):
|
|||||||
initializer_factor=1.0,
|
initializer_factor=1.0,
|
||||||
is_encoder_decoder=True,
|
is_encoder_decoder=True,
|
||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
eos_token_ids=[1],
|
eos_token_id=1,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
is_encoder_decoder=is_encoder_decoder, **kwargs,
|
pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, **kwargs,
|
||||||
)
|
)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.n_positions = n_positions
|
self.n_positions = n_positions
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ class TransfoXLConfig(PretrainedConfig):
|
|||||||
eos_token_id=0,
|
eos_token_id=0,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(eos_token_id=eos_token_id, **kwargs)
|
||||||
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.cutoffs = []
|
self.cutoffs = []
|
||||||
@@ -187,8 +187,6 @@ class TransfoXLConfig(PretrainedConfig):
|
|||||||
self.init_std = init_std
|
self.init_std = init_std
|
||||||
self.layer_norm_epsilon = layer_norm_epsilon
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
|
|
||||||
self.eos_token_ids = [eos_token_id]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_position_embeddings(self):
|
def max_position_embeddings(self):
|
||||||
return self.tgt_len + self.ext_len + self.mem_len
|
return self.tgt_len + self.ext_len + self.mem_len
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class PretrainedConfig(object):
|
|||||||
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
||||||
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
||||||
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
||||||
self.eos_token_ids = kwargs.pop("eos_token_ids", None)
|
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
||||||
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
||||||
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
||||||
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
||||||
|
|||||||
@@ -194,13 +194,13 @@ class XLMConfig(PretrainedConfig):
|
|||||||
end_n_top=5,
|
end_n_top=5,
|
||||||
mask_token_id=0,
|
mask_token_id=0,
|
||||||
lang_id=0,
|
lang_id=0,
|
||||||
bos_token_id=0,
|
|
||||||
pad_token_id=2,
|
pad_token_id=2,
|
||||||
|
bos_token_id=0,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
"""Constructs XLMConfig.
|
"""Constructs XLMConfig.
|
||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.emb_dim = emb_dim
|
self.emb_dim = emb_dim
|
||||||
self.n_layers = n_layers
|
self.n_layers = n_layers
|
||||||
@@ -236,9 +236,6 @@ class XLMConfig(PretrainedConfig):
|
|||||||
if "n_words" in kwargs:
|
if "n_words" in kwargs:
|
||||||
self.n_words = kwargs["n_words"]
|
self.n_words = kwargs["n_words"]
|
||||||
|
|
||||||
self.bos_token_id = bos_token_id
|
|
||||||
self.pad_token_id = pad_token_id
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_words(self): # For backward compatibility
|
def n_words(self): # For backward compatibility
|
||||||
return self.vocab_size
|
return self.vocab_size
|
||||||
|
|||||||
@@ -155,14 +155,14 @@ class XLNetConfig(PretrainedConfig):
|
|||||||
summary_last_dropout=0.1,
|
summary_last_dropout=0.1,
|
||||||
start_n_top=5,
|
start_n_top=5,
|
||||||
end_n_top=5,
|
end_n_top=5,
|
||||||
bos_token_id=1,
|
|
||||||
pad_token_id=5,
|
pad_token_id=5,
|
||||||
|
bos_token_id=1,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
"""Constructs XLNetConfig.
|
"""Constructs XLNetConfig.
|
||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.n_layer = n_layer
|
self.n_layer = n_layer
|
||||||
@@ -193,7 +193,7 @@ class XLNetConfig(PretrainedConfig):
|
|||||||
|
|
||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
self.pad_token_id = pad_token_id
|
self.pad_token_id = pad_token_id
|
||||||
self.eos_token_ids = [eos_token_id]
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_position_embeddings(self):
|
def max_position_embeddings(self):
|
||||||
|
|||||||
@@ -906,8 +906,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
def prepare_scores_for_generation(self, scores, cur_len, max_length):
|
||||||
if cur_len == 1:
|
if cur_len == 1:
|
||||||
self._force_token_ids_generation(scores, self.config.bos_token_id)
|
self._force_token_ids_generation(scores, self.config.bos_token_id)
|
||||||
if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None:
|
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||||
self._force_token_ids_generation(scores, self.config.eos_token_ids[0])
|
self._force_token_ids_generation(scores, self.config.eos_token_id)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -1003,7 +1003,7 @@ class BartForSequenceClassification(PretrainedBartModel):
|
|||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
)
|
)
|
||||||
x = outputs[0] # last hidden state
|
x = outputs[0] # last hidden state
|
||||||
eos_mask = input_ids.eq(self.config.eos_token_ids[0])
|
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
if len(torch.unique(eos_mask.sum(1))) > 1:
|
||||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||||
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
|
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
|
||||||
|
|||||||
@@ -469,7 +469,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
repetition_penalty=None,
|
repetition_penalty=None,
|
||||||
bos_token_id=None,
|
bos_token_id=None,
|
||||||
pad_token_id=None,
|
pad_token_id=None,
|
||||||
eos_token_ids=None,
|
eos_token_id=None,
|
||||||
length_penalty=None,
|
length_penalty=None,
|
||||||
no_repeat_ngram_size=None,
|
no_repeat_ngram_size=None,
|
||||||
num_return_sequences=None,
|
num_return_sequences=None,
|
||||||
@@ -518,13 +518,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
||||||
|
|
||||||
bos_token_id: (`optional`) int
|
bos_token_id: (`optional`) int
|
||||||
Beginning of sentence token if no prompt is provided. Default to 0.
|
Beginning of sentence token if no prompt is provided. Default to specicic model bos_token_id or None if it does not exist.
|
||||||
|
|
||||||
pad_token_id: (`optional`) int
|
pad_token_id: (`optional`) int
|
||||||
Pad token. Defaults to pad_token_id as defined in the models config.
|
Pad token. Defaults to pad_token_id as defined in the models config.
|
||||||
|
|
||||||
eos_token_ids: (`optional`) int or list of int
|
eos_token_ids: (`optional`) int or list of int
|
||||||
End of sequence token or list of tokens to stop the generation. Default to 0.
|
End of sequence token or list of tokens to stop the generation. Default to 0.
|
||||||
|
|
||||||
length_penalty: (`optional`) float
|
length_penalty: (`optional`) float
|
||||||
Exponential penalty to the length. Default to 1.
|
Exponential penalty to the length. Default to 1.
|
||||||
|
|
||||||
@@ -601,7 +602,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
||||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||||
no_repeat_ngram_size = (
|
no_repeat_ngram_size = (
|
||||||
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||||
@@ -615,8 +616,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
|
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
|
||||||
else:
|
else:
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
if isinstance(eos_token_ids, int):
|
|
||||||
eos_token_ids = [eos_token_ids]
|
|
||||||
|
|
||||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
||||||
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
||||||
@@ -633,9 +632,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
assert pad_token_id is None or (
|
assert pad_token_id is None or (
|
||||||
isinstance(pad_token_id, int) and (pad_token_id >= 0)
|
isinstance(pad_token_id, int) and (pad_token_id >= 0)
|
||||||
), "`pad_token_id` should be a positive integer."
|
), "`pad_token_id` should be a positive integer."
|
||||||
assert (eos_token_ids is None) or (
|
assert (eos_token_id is None) or (
|
||||||
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
|
isinstance(eos_token_id, int) and (eos_token_id >= 0)
|
||||||
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
), "`eos_token_id` should be a positive integer."
|
||||||
assert (
|
assert (
|
||||||
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
|
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
|
||||||
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
|
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
|
||||||
@@ -674,11 +673,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
elif attention_mask is None:
|
elif attention_mask is None:
|
||||||
attention_mask = tf.ones_like(input_ids)
|
attention_mask = tf.ones_like(input_ids)
|
||||||
|
|
||||||
if pad_token_id is None and eos_token_ids is not None:
|
if pad_token_id is None and eos_token_id is not None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
|
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
|
||||||
)
|
)
|
||||||
pad_token_id = eos_token_ids[0]
|
pad_token_id = eos_token_id
|
||||||
|
|
||||||
# current position and vocab size
|
# current position and vocab size
|
||||||
cur_len = shape_list(input_ids)[1]
|
cur_len = shape_list(input_ids)[1]
|
||||||
@@ -742,7 +741,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
bos_token_id=bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_ids=eos_token_ids,
|
eos_token_id=eos_token_id,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
batch_size=effective_batch_size,
|
batch_size=effective_batch_size,
|
||||||
num_return_sequences=num_return_sequences,
|
num_return_sequences=num_return_sequences,
|
||||||
@@ -766,7 +765,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
bos_token_id=bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_ids=eos_token_ids,
|
eos_token_id=eos_token_id,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
batch_size=effective_batch_size,
|
batch_size=effective_batch_size,
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
@@ -790,7 +789,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
no_repeat_ngram_size,
|
no_repeat_ngram_size,
|
||||||
bos_token_id,
|
bos_token_id,
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_id,
|
||||||
decoder_start_token_id,
|
decoder_start_token_id,
|
||||||
batch_size,
|
batch_size,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
@@ -839,10 +838,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# set eos token prob to zero if min_length is not reached
|
# set eos token prob to zero if min_length is not reached
|
||||||
if eos_token_ids is not None and cur_len < min_length:
|
if eos_token_id is not None and cur_len < min_length:
|
||||||
# create eos_token_ids boolean mask
|
# create eos_token_id boolean mask
|
||||||
is_token_logit_eos_token = tf.convert_to_tensor(
|
is_token_logit_eos_token = tf.convert_to_tensor(
|
||||||
[True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
|
[True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
|
||||||
)
|
)
|
||||||
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
||||||
|
|
||||||
@@ -865,16 +864,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)
|
next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)
|
||||||
|
|
||||||
# update generations and finished sentences
|
# update generations and finished sentences
|
||||||
if eos_token_ids is not None:
|
if eos_token_id is not None:
|
||||||
# pad finished sentences if eos_token_ids exist
|
# pad finished sentences if eos_token_id exist
|
||||||
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
|
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
|
||||||
else:
|
else:
|
||||||
tokens_to_add = next_token
|
tokens_to_add = next_token
|
||||||
|
|
||||||
input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
|
input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
|
||||||
|
|
||||||
if eos_token_ids is not None:
|
if eos_token_id is not None:
|
||||||
for eos_token_id in eos_token_ids:
|
|
||||||
eos_in_sents = tokens_to_add == eos_token_id
|
eos_in_sents = tokens_to_add == eos_token_id
|
||||||
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
||||||
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
|
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
|
||||||
@@ -937,8 +935,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
no_repeat_ngram_size,
|
no_repeat_ngram_size,
|
||||||
bos_token_id,
|
bos_token_id,
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_ids,
|
|
||||||
decoder_start_token_id,
|
decoder_start_token_id,
|
||||||
|
eos_token_id,
|
||||||
batch_size,
|
batch_size,
|
||||||
num_return_sequences,
|
num_return_sequences,
|
||||||
length_penalty,
|
length_penalty,
|
||||||
@@ -996,10 +994,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
# set eos token prob to zero if min_length is not reached
|
# set eos token prob to zero if min_length is not reached
|
||||||
if eos_token_ids is not None and cur_len < min_length:
|
if eos_token_id is not None and cur_len < min_length:
|
||||||
# create eos_token_ids boolean mask
|
# create eos_token_id boolean mask
|
||||||
is_token_logit_eos_token = tf.convert_to_tensor(
|
is_token_logit_eos_token = tf.convert_to_tensor(
|
||||||
[True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool
|
[True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
|
||||||
)
|
)
|
||||||
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
||||||
|
|
||||||
@@ -1072,7 +1070,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
len(generated_hyps[batch_idx]) >= num_beams
|
len(generated_hyps[batch_idx]) >= num_beams
|
||||||
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
|
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
|
||||||
assert (
|
assert (
|
||||||
eos_token_ids is not None and pad_token_id is not None
|
eos_token_id is not None and pad_token_id is not None
|
||||||
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
|
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
|
||||||
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
||||||
continue
|
continue
|
||||||
@@ -1091,7 +1089,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
|
|
||||||
effective_beam_id = batch_idx * num_beams + beam_id
|
effective_beam_id = batch_idx * num_beams + beam_id
|
||||||
# add to generated hypotheses if end of sentence or last iteration
|
# add to generated hypotheses if end of sentence or last iteration
|
||||||
if eos_token_ids is not None and token_id.numpy() in eos_token_ids:
|
if eos_token_id is not None and token_id.numpy() is eos_token_id:
|
||||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||||
if is_beam_token_worse_than_top_num_beams:
|
if is_beam_token_worse_than_top_num_beams:
|
||||||
@@ -1148,8 +1146,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
if done[batch_idx]:
|
if done[batch_idx]:
|
||||||
continue
|
continue
|
||||||
# test that beam scores match previously calculated scores if not eos and batch_idx not done
|
# test that beam scores match previously calculated scores if not eos and batch_idx not done
|
||||||
if eos_token_ids is not None and all(
|
if eos_token_id is not None and all(
|
||||||
(token_id % vocab_size).numpy().item() not in eos_token_ids for token_id in next_tokens[batch_idx]
|
(token_id % vocab_size).numpy().item() is not eos_token_id for token_id in next_tokens[batch_idx]
|
||||||
):
|
):
|
||||||
assert tf.reduce_all(
|
assert tf.reduce_all(
|
||||||
next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
|
next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
|
||||||
@@ -1199,7 +1197,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
if sent_lengths[i] < max_length:
|
if sent_lengths[i] < max_length:
|
||||||
decoded_hypo = tf.where(
|
decoded_hypo = tf.where(
|
||||||
tf.range(max_length) == sent_lengths[i],
|
tf.range(max_length) == sent_lengths[i],
|
||||||
eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
|
eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
|
||||||
decoded_hypo,
|
decoded_hypo,
|
||||||
)
|
)
|
||||||
decoded_list.append(decoded_hypo)
|
decoded_list.append(decoded_hypo)
|
||||||
|
|||||||
@@ -665,7 +665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
repetition_penalty=None,
|
repetition_penalty=None,
|
||||||
bos_token_id=None,
|
bos_token_id=None,
|
||||||
pad_token_id=None,
|
pad_token_id=None,
|
||||||
eos_token_ids=None,
|
eos_token_id=None,
|
||||||
length_penalty=None,
|
length_penalty=None,
|
||||||
no_repeat_ngram_size=None,
|
no_repeat_ngram_size=None,
|
||||||
num_return_sequences=None,
|
num_return_sequences=None,
|
||||||
@@ -713,6 +713,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
repetition_penalty: (`optional`) float
|
repetition_penalty: (`optional`) float
|
||||||
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
||||||
|
|
||||||
|
pad_token_id: (`optional`) int
|
||||||
|
Padding token. Default to specicic model pad_token_id or None if it does not exist.
|
||||||
|
|
||||||
bos_token_id: (`optional`) int
|
bos_token_id: (`optional`) int
|
||||||
BOS token. Defaults to bos_token_id as defined in the models config.
|
BOS token. Defaults to bos_token_id as defined in the models config.
|
||||||
|
|
||||||
@@ -798,7 +801,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
||||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||||
no_repeat_ngram_size = (
|
no_repeat_ngram_size = (
|
||||||
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||||
@@ -812,8 +815,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
batch_size = input_ids.shape[0] # overriden by the input batch_size
|
batch_size = input_ids.shape[0] # overriden by the input batch_size
|
||||||
else:
|
else:
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
if isinstance(eos_token_ids, int):
|
|
||||||
eos_token_ids = [eos_token_ids]
|
|
||||||
|
|
||||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
||||||
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
||||||
@@ -830,12 +831,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
assert pad_token_id is None or (
|
assert pad_token_id is None or (
|
||||||
isinstance(pad_token_id, int) and (pad_token_id >= 0)
|
isinstance(pad_token_id, int) and (pad_token_id >= 0)
|
||||||
), "`pad_token_id` should be a positive integer."
|
), "`pad_token_id` should be a positive integer."
|
||||||
assert (eos_token_ids is None) or (
|
|
||||||
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
|
|
||||||
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
|
|
||||||
assert (
|
assert (
|
||||||
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
|
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
|
||||||
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
|
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
|
||||||
|
assert (eos_token_id is None) or (
|
||||||
|
isinstance(eos_token_id, int) and (eos_token_id >= 0)
|
||||||
|
), "`eos_token_id` should be a positive integer."
|
||||||
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
||||||
assert (
|
assert (
|
||||||
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
|
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
|
||||||
@@ -876,13 +877,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
elif attention_mask is None:
|
elif attention_mask is None:
|
||||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||||
|
|
||||||
# set pad_token_id to eos_token_ids if not set. Important that this is done after
|
# set pad_token_id to eos_token_id if not set. Important that this is done after
|
||||||
# attention_mask is created
|
# attention_mask is created
|
||||||
if pad_token_id is None and eos_token_ids is not None:
|
if pad_token_id is None and eos_token_id is not None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
|
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
|
||||||
)
|
)
|
||||||
pad_token_id = eos_token_ids[0]
|
pad_token_id = eos_token_id
|
||||||
|
|
||||||
# current position and vocab size
|
# current position and vocab size
|
||||||
vocab_size = self.config.vocab_size
|
vocab_size = self.config.vocab_size
|
||||||
@@ -947,8 +948,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
bos_token_id=bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_ids=eos_token_ids,
|
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
batch_size=effective_batch_size,
|
batch_size=effective_batch_size,
|
||||||
num_return_sequences=num_return_sequences,
|
num_return_sequences=num_return_sequences,
|
||||||
length_penalty=length_penalty,
|
length_penalty=length_penalty,
|
||||||
@@ -971,8 +972,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||||
bos_token_id=bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_ids=eos_token_ids,
|
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
batch_size=effective_batch_size,
|
batch_size=effective_batch_size,
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
@@ -994,7 +995,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
no_repeat_ngram_size,
|
no_repeat_ngram_size,
|
||||||
bos_token_id,
|
bos_token_id,
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_id,
|
||||||
decoder_start_token_id,
|
decoder_start_token_id,
|
||||||
batch_size,
|
batch_size,
|
||||||
encoder_outputs,
|
encoder_outputs,
|
||||||
@@ -1031,8 +1032,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
|
next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
|
||||||
|
|
||||||
# set eos token prob to zero if min_length is not reached
|
# set eos token prob to zero if min_length is not reached
|
||||||
if eos_token_ids is not None and cur_len < min_length:
|
if eos_token_id is not None and cur_len < min_length:
|
||||||
for eos_token_id in eos_token_ids:
|
|
||||||
next_token_logits[:, eos_token_id] = -float("inf")
|
next_token_logits[:, eos_token_id] = -float("inf")
|
||||||
|
|
||||||
if do_sample:
|
if do_sample:
|
||||||
@@ -1049,16 +1049,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
next_token = torch.argmax(next_token_logits, dim=-1)
|
next_token = torch.argmax(next_token_logits, dim=-1)
|
||||||
|
|
||||||
# update generations and finished sentences
|
# update generations and finished sentences
|
||||||
if eos_token_ids is not None:
|
if eos_token_id is not None:
|
||||||
# pad finished sentences if eos_token_ids exist
|
# pad finished sentences if eos_token_id exist
|
||||||
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
|
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
|
||||||
else:
|
else:
|
||||||
tokens_to_add = next_token
|
tokens_to_add = next_token
|
||||||
|
|
||||||
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
|
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
|
||||||
|
|
||||||
if eos_token_ids is not None:
|
if eos_token_id is not None:
|
||||||
for eos_token_id in eos_token_ids:
|
|
||||||
eos_in_sents = tokens_to_add == eos_token_id
|
eos_in_sents = tokens_to_add == eos_token_id
|
||||||
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
||||||
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
|
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
|
||||||
@@ -1106,7 +1105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
no_repeat_ngram_size,
|
no_repeat_ngram_size,
|
||||||
bos_token_id,
|
bos_token_id,
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
eos_token_ids,
|
eos_token_id,
|
||||||
decoder_start_token_id,
|
decoder_start_token_id,
|
||||||
batch_size,
|
batch_size,
|
||||||
num_return_sequences,
|
num_return_sequences,
|
||||||
@@ -1163,8 +1162,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
|
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
|
||||||
|
|
||||||
# set eos token prob to zero if min_length is not reached
|
# set eos token prob to zero if min_length is not reached
|
||||||
if eos_token_ids is not None and cur_len < min_length:
|
if eos_token_id is not None and cur_len < min_length:
|
||||||
for eos_token_id in eos_token_ids:
|
|
||||||
scores[:, eos_token_id] = -float("inf")
|
scores[:, eos_token_id] = -float("inf")
|
||||||
|
|
||||||
if no_repeat_ngram_size > 0:
|
if no_repeat_ngram_size > 0:
|
||||||
@@ -1225,7 +1223,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
len(generated_hyps[batch_idx]) >= num_beams
|
len(generated_hyps[batch_idx]) >= num_beams
|
||||||
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
|
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
|
||||||
assert (
|
assert (
|
||||||
eos_token_ids is not None and pad_token_id is not None
|
eos_token_id is not None and pad_token_id is not None
|
||||||
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
|
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
|
||||||
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
||||||
continue
|
continue
|
||||||
@@ -1244,7 +1242,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
effective_beam_id = batch_idx * num_beams + beam_id
|
effective_beam_id = batch_idx * num_beams + beam_id
|
||||||
|
|
||||||
# add to generated hypotheses if end of sentence
|
# add to generated hypotheses if end of sentence
|
||||||
if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
|
if (eos_token_id is not None) and (token_id.item() is eos_token_id):
|
||||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||||
if is_beam_token_worse_than_top_num_beams:
|
if is_beam_token_worse_than_top_num_beams:
|
||||||
@@ -1303,8 +1301,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# test that beam scores match previously calculated scores if not eos and batch_idx not done
|
# test that beam scores match previously calculated scores if not eos and batch_idx not done
|
||||||
if eos_token_ids is not None and all(
|
if eos_token_id is not None and all(
|
||||||
(token_id % vocab_size).item() not in eos_token_ids for token_id in next_tokens[batch_idx]
|
(token_id % vocab_size).item() is not eos_token_id for token_id in next_tokens[batch_idx]
|
||||||
):
|
):
|
||||||
assert torch.all(
|
assert torch.all(
|
||||||
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
|
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
|
||||||
@@ -1346,7 +1344,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
for i, hypo in enumerate(best):
|
for i, hypo in enumerate(best):
|
||||||
decoded[i, : sent_lengths[i]] = hypo
|
decoded[i, : sent_lengths[i]] = hypo
|
||||||
if sent_lengths[i] < max_length:
|
if sent_lengths[i] < max_length:
|
||||||
decoded[i, sent_lengths[i]] = eos_token_ids[0]
|
decoded[i, sent_lengths[i]] = eos_token_id
|
||||||
else:
|
else:
|
||||||
# none of the hypotheses have an eos_token
|
# none of the hypotheses have an eos_token
|
||||||
assert (len(hypo) == max_length for hypo in best)
|
assert (len(hypo) == max_length for hypo in best)
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ class ModelTester:
|
|||||||
self.hidden_dropout_prob = 0.1
|
self.hidden_dropout_prob = 0.1
|
||||||
self.attention_probs_dropout_prob = 0.1
|
self.attention_probs_dropout_prob = 0.1
|
||||||
self.max_position_embeddings = 20
|
self.max_position_embeddings = 20
|
||||||
self.eos_token_ids = [2]
|
self.eos_token_id = 2
|
||||||
self.pad_token_id = 1
|
self.pad_token_id = 1
|
||||||
self.bos_token_id = 0
|
self.bos_token_id = 0
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
@@ -82,7 +82,7 @@ class ModelTester:
|
|||||||
dropout=self.hidden_dropout_prob,
|
dropout=self.hidden_dropout_prob,
|
||||||
attention_dropout=self.attention_probs_dropout_prob,
|
attention_dropout=self.attention_probs_dropout_prob,
|
||||||
max_position_embeddings=self.max_position_embeddings,
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
eos_token_ids=self.eos_token_ids,
|
eos_token_id=self.eos_token_id,
|
||||||
bos_token_id=self.bos_token_id,
|
bos_token_id=self.bos_token_id,
|
||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
)
|
)
|
||||||
@@ -214,7 +214,7 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
decoder_ffn_dim=32,
|
decoder_ffn_dim=32,
|
||||||
max_position_embeddings=48,
|
max_position_embeddings=48,
|
||||||
output_past=output_past,
|
output_past=output_past,
|
||||||
eos_token_ids=[2],
|
eos_token_id=2,
|
||||||
pad_token_id=1,
|
pad_token_id=1,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
)
|
)
|
||||||
@@ -274,7 +274,7 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
decoder_ffn_dim=32,
|
decoder_ffn_dim=32,
|
||||||
max_position_embeddings=48,
|
max_position_embeddings=48,
|
||||||
output_past=True,
|
output_past=True,
|
||||||
eos_token_ids=[2],
|
eos_token_id=2,
|
||||||
pad_token_id=1,
|
pad_token_id=1,
|
||||||
bos_token_id=0,
|
bos_token_id=0,
|
||||||
)
|
)
|
||||||
@@ -483,7 +483,7 @@ class BartModelIntegrationTests(unittest.TestCase):
|
|||||||
no_repeat_ngram_size=3,
|
no_repeat_ngram_size=3,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
early_stopping=True,
|
early_stopping=True,
|
||||||
decoder_start_token_id=hf.config.eos_token_ids[0],
|
decoder_start_token_id=hf.config.eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
decoded = [
|
decoded = [
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
# type_vocab_size=self.type_vocab_size,
|
# type_vocab_size=self.type_vocab_size,
|
||||||
# initializer_range=self.initializer_range
|
# initializer_range=self.initializer_range
|
||||||
bos_token_id=self.bos_token_id,
|
bos_token_id=self.bos_token_id,
|
||||||
eos_token_ids=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# type_vocab_size=self.type_vocab_size,
|
# type_vocab_size=self.type_vocab_size,
|
||||||
# initializer_range=self.initializer_range
|
# initializer_range=self.initializer_range
|
||||||
bos_token_id=self.bos_token_id,
|
bos_token_id=self.bos_token_id,
|
||||||
eos_token_ids=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
d_inner=self.d_inner,
|
d_inner=self.d_inner,
|
||||||
div_val=self.div_val,
|
div_val=self.div_val,
|
||||||
n_layer=self.num_hidden_layers,
|
n_layer=self.num_hidden_layers,
|
||||||
eos_token_ids=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (config, input_ids_1, input_ids_2, lm_labels)
|
return (config, input_ids_1, input_ids_2, lm_labels)
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
d_inner=self.d_inner,
|
d_inner=self.d_inner,
|
||||||
div_val=self.div_val,
|
div_val=self.div_val,
|
||||||
n_layer=self.num_hidden_layers,
|
n_layer=self.num_hidden_layers,
|
||||||
eos_token_ids=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (config, input_ids_1, input_ids_2, lm_labels)
|
return (config, input_ids_1, input_ids_2, lm_labels)
|
||||||
|
|||||||
Reference in New Issue
Block a user