From 95e00d08082d6e87e6c61d1f78b401f4ec337317 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 20 Mar 2020 21:41:04 +0100 Subject: [PATCH] Clean special token init in modeling_....py (#3264) * make style * fix conflicts --- examples/summarization/bart/evaluate_cnn.py | 2 +- src/transformers/__init__.py | 1 + src/transformers/configuration_albert.py | 5 +- src/transformers/configuration_bart.py | 8 +- src/transformers/configuration_bert.py | 3 +- src/transformers/configuration_distilbert.py | 3 +- src/transformers/configuration_flaubert.py | 4 +- src/transformers/configuration_gpt2.py | 4 +- src/transformers/configuration_roberta.py | 5 ++ src/transformers/configuration_t5.py | 4 +- src/transformers/configuration_transfo_xl.py | 4 +- src/transformers/configuration_utils.py | 2 +- src/transformers/configuration_xlm.py | 7 +- src/transformers/configuration_xlnet.py | 6 +- src/transformers/modeling_bart.py | 6 +- src/transformers/modeling_tf_utils.py | 80 ++++++++++---------- src/transformers/modeling_utils.py | 70 +++++++++-------- tests/test_modeling_bart.py | 10 +-- tests/test_modeling_gpt2.py | 2 +- tests/test_modeling_tf_gpt2.py | 2 +- tests/test_modeling_tf_transfo_xl.py | 2 +- tests/test_modeling_transfo_xl.py | 2 +- 22 files changed, 117 insertions(+), 115 deletions(-) diff --git a/examples/summarization/bart/evaluate_cnn.py b/examples/summarization/bart/evaluate_cnn.py index b6a2eb7bdf..5c69dc921f 100644 --- a/examples/summarization/bart/evaluate_cnn.py +++ b/examples/summarization/bart/evaluate_cnn.py @@ -35,7 +35,7 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE): min_length=min_length + 1, # +1 from original because we start at step=1 no_repeat_ngram_size=3, early_stopping=True, - decoder_start_token_id=model.config.eos_token_ids[0], + decoder_start_token_id=model.config.eos_token_id, ) dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] for hypothesis in dec: 
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 82016686e2..8ea0354dab 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -223,6 +223,7 @@ if is_torch_available(): BartForSequenceClassification, BartModel, BartForConditionalGeneration, + BART_PRETRAINED_MODEL_ARCHIVE_MAP, ) from .modeling_roberta import ( RobertaForMaskedLM, diff --git a/src/transformers/configuration_albert.py b/src/transformers/configuration_albert.py index 3419753cb1..59ba0fe071 100644 --- a/src/transformers/configuration_albert.py +++ b/src/transformers/configuration_albert.py @@ -124,9 +124,12 @@ class AlbertConfig(PretrainedConfig): initializer_range=0.02, layer_norm_eps=1e-12, classifier_dropout_prob=0.1, + pad_token_id=0, + bos_token_id=2, + eos_token_id=3, **kwargs ): - super().__init__(**kwargs) + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) self.vocab_size = vocab_size self.embedding_size = embedding_size diff --git a/src/transformers/configuration_bart.py b/src/transformers/configuration_bart.py index 3bb26ead68..077e257a58 100644 --- a/src/transformers/configuration_bart.py +++ b/src/transformers/configuration_bart.py @@ -41,9 +41,6 @@ class BartConfig(PretrainedConfig): activation_dropout=0.0, activation_function="gelu", vocab_size=50265, - bos_token_id=0, - pad_token_id=1, - eos_token_ids=[2], d_model=1024, encoder_ffn_dim=4096, encoder_layers=12, @@ -61,6 +58,9 @@ class BartConfig(PretrainedConfig): output_past=False, num_labels=3, is_encoder_decoder=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, **common_kwargs ): r""" @@ -74,7 +74,7 @@ class BartConfig(PretrainedConfig): output_past=output_past, pad_token_id=pad_token_id, bos_token_id=bos_token_id, - eos_token_ids=eos_token_ids, + eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, **common_kwargs, ) diff --git a/src/transformers/configuration_bert.py 
b/src/transformers/configuration_bert.py index d668d04cb8..caa1055738 100644 --- a/src/transformers/configuration_bert.py +++ b/src/transformers/configuration_bert.py @@ -124,9 +124,10 @@ class BertConfig(PretrainedConfig): type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, + pad_token_id=0, **kwargs ): - super().__init__(**kwargs) + super().__init__(pad_token_id=pad_token_id, **kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size diff --git a/src/transformers/configuration_distilbert.py b/src/transformers/configuration_distilbert.py index 217dc7eb03..1d528297bb 100644 --- a/src/transformers/configuration_distilbert.py +++ b/src/transformers/configuration_distilbert.py @@ -113,9 +113,10 @@ class DistilBertConfig(PretrainedConfig): initializer_range=0.02, qa_dropout=0.1, seq_classif_dropout=0.2, + pad_token_id=0, **kwargs ): - super().__init__(**kwargs) + super().__init__(**kwargs, pad_token_id=pad_token_id) self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.sinusoidal_pos_embds = sinusoidal_pos_embds diff --git a/src/transformers/configuration_flaubert.py b/src/transformers/configuration_flaubert.py index 0c9860cbed..c807f63d38 100644 --- a/src/transformers/configuration_flaubert.py +++ b/src/transformers/configuration_flaubert.py @@ -145,9 +145,9 @@ class FlaubertConfig(XLMConfig): pretrained_config_archive_map = FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP model_type = "flaubert" - def __init__(self, layerdrop=0.0, pre_norm=False, **kwargs): + def __init__(self, layerdrop=0.0, pre_norm=False, pad_token_id=2, bos_token_id=0, **kwargs): """Constructs FlaubertConfig. 
""" - super().__init__(**kwargs) + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs) self.layerdrop = layerdrop self.pre_norm = pre_norm diff --git a/src/transformers/configuration_gpt2.py b/src/transformers/configuration_gpt2.py index 1f2352a6c9..0e85a91821 100644 --- a/src/transformers/configuration_gpt2.py +++ b/src/transformers/configuration_gpt2.py @@ -142,7 +142,7 @@ class GPT2Config(PretrainedConfig): eos_token_id=50256, **kwargs ): - super().__init__(**kwargs) + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) self.vocab_size = vocab_size self.n_ctx = n_ctx @@ -163,7 +163,7 @@ class GPT2Config(PretrainedConfig): self.summary_proj_to_labels = summary_proj_to_labels self.bos_token_id = bos_token_id - self.eos_token_ids = [eos_token_id] + self.eos_token_id = eos_token_id @property def max_position_embeddings(self): diff --git a/src/transformers/configuration_roberta.py b/src/transformers/configuration_roberta.py index 655fe03b71..03bdfe3031 100644 --- a/src/transformers/configuration_roberta.py +++ b/src/transformers/configuration_roberta.py @@ -66,3 +66,8 @@ class RobertaConfig(BertConfig): """ pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP model_type = "roberta" + + def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs): + """Constructs RobertaConfig. 
+ """ + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/configuration_t5.py b/src/transformers/configuration_t5.py index a86bb2f3bf..6f1ab56fb3 100644 --- a/src/transformers/configuration_t5.py +++ b/src/transformers/configuration_t5.py @@ -77,11 +77,11 @@ class T5Config(PretrainedConfig): initializer_factor=1.0, is_encoder_decoder=True, pad_token_id=0, - eos_token_ids=[1], + eos_token_id=1, **kwargs ): super().__init__( - is_encoder_decoder=is_encoder_decoder, **kwargs, + pad_token_id=pad_token_id, eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, **kwargs, ) self.vocab_size = vocab_size self.n_positions = n_positions diff --git a/src/transformers/configuration_transfo_xl.py b/src/transformers/configuration_transfo_xl.py index acbc1f00c6..2e484d327c 100644 --- a/src/transformers/configuration_transfo_xl.py +++ b/src/transformers/configuration_transfo_xl.py @@ -152,7 +152,7 @@ class TransfoXLConfig(PretrainedConfig): eos_token_id=0, **kwargs ): - super().__init__(**kwargs) + super().__init__(eos_token_id=eos_token_id, **kwargs) self.vocab_size = vocab_size self.cutoffs = [] @@ -187,8 +187,6 @@ class TransfoXLConfig(PretrainedConfig): self.init_std = init_std self.layer_norm_epsilon = layer_norm_epsilon - self.eos_token_ids = [eos_token_id] - @property def max_position_embeddings(self): return self.tgt_len + self.ext_len + self.mem_len diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 020f3ea49d..83185bc10f 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -80,7 +80,7 @@ class PretrainedConfig(object): self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) self.bos_token_id = kwargs.pop("bos_token_id", None) self.pad_token_id = kwargs.pop("pad_token_id", None) - self.eos_token_ids = kwargs.pop("eos_token_ids", None) + self.eos_token_id = 
kwargs.pop("eos_token_id", None) self.length_penalty = kwargs.pop("length_penalty", 1.0) self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) self.num_return_sequences = kwargs.pop("num_return_sequences", 1) diff --git a/src/transformers/configuration_xlm.py b/src/transformers/configuration_xlm.py index ec764b6867..73fbd99a19 100644 --- a/src/transformers/configuration_xlm.py +++ b/src/transformers/configuration_xlm.py @@ -194,13 +194,13 @@ class XLMConfig(PretrainedConfig): end_n_top=5, mask_token_id=0, lang_id=0, - bos_token_id=0, pad_token_id=2, + bos_token_id=0, **kwargs ): """Constructs XLMConfig. """ - super().__init__(**kwargs) + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs) self.vocab_size = vocab_size self.emb_dim = emb_dim self.n_layers = n_layers @@ -236,9 +236,6 @@ class XLMConfig(PretrainedConfig): if "n_words" in kwargs: self.n_words = kwargs["n_words"] - self.bos_token_id = bos_token_id - self.pad_token_id = pad_token_id - @property def n_words(self): # For backward compatibility return self.vocab_size diff --git a/src/transformers/configuration_xlnet.py b/src/transformers/configuration_xlnet.py index 129e315043..109d74fb25 100644 --- a/src/transformers/configuration_xlnet.py +++ b/src/transformers/configuration_xlnet.py @@ -155,14 +155,14 @@ class XLNetConfig(PretrainedConfig): summary_last_dropout=0.1, start_n_top=5, end_n_top=5, - bos_token_id=1, pad_token_id=5, + bos_token_id=1, eos_token_id=2, **kwargs ): """Constructs XLNetConfig. 
""" - super().__init__(**kwargs) + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) self.vocab_size = vocab_size self.d_model = d_model self.n_layer = n_layer @@ -193,7 +193,7 @@ class XLNetConfig(PretrainedConfig): self.bos_token_id = bos_token_id self.pad_token_id = pad_token_id - self.eos_token_ids = [eos_token_id] + self.eos_token_id = eos_token_id @property def max_position_embeddings(self): diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index c74b08ef1b..ac1764de8b 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -906,8 +906,8 @@ class BartForConditionalGeneration(PretrainedBartModel): def prepare_scores_for_generation(self, scores, cur_len, max_length): if cur_len == 1: self._force_token_ids_generation(scores, self.config.bos_token_id) - if cur_len == max_length - 1 and self.config.eos_token_ids[0] is not None: - self._force_token_ids_generation(scores, self.config.eos_token_ids[0]) + if cur_len == max_length - 1 and self.config.eos_token_id is not None: + self._force_token_ids_generation(scores, self.config.eos_token_id) return scores @staticmethod @@ -1003,7 +1003,7 @@ class BartForSequenceClassification(PretrainedBartModel): encoder_outputs=encoder_outputs, ) x = outputs[0] # last hidden state - eos_mask = input_ids.eq(self.config.eos_token_ids[0]) + eos_mask = input_ids.eq(self.config.eos_token_id) if len(torch.unique(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index a9767ccfa5..07d722f945 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -469,7 +469,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): repetition_penalty=None, 
bos_token_id=None, pad_token_id=None, - eos_token_ids=None, + eos_token_id=None, length_penalty=None, no_repeat_ngram_size=None, num_return_sequences=None, @@ -518,13 +518,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0. bos_token_id: (`optional`) int - Beginning of sentence token if no prompt is provided. Default to 0. + Beginning of sentence token if no prompt is provided. Default to specific model bos_token_id or None if it does not exist. pad_token_id: (`optional`) int Pad token. Defaults to pad_token_id as defined in the models config. eos_token_ids: (`optional`) int or list of int End of sequence token or list of tokens to stop the generation. Default to 0. + length_penalty: (`optional`) float Exponential penalty to the length. Default to 1. @@ -601,7 +602,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty no_repeat_ngram_size = ( no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size @@ -615,8 +616,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): batch_size = shape_list(input_ids)[0] # overriden by the input batch_size else: batch_size = 1 - if isinstance(eos_token_ids, int): - eos_token_ids = [eos_token_ids] assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer." 
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." @@ -633,9 +632,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): assert pad_token_id is None or ( isinstance(pad_token_id, int) and (pad_token_id >= 0) ), "`pad_token_id` should be a positive integer." - assert (eos_token_ids is None) or ( - isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids) - ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers." + assert (eos_token_id is None) or ( + isinstance(eos_token_id, int) and (eos_token_id >= 0) + ), "`eos_token_id` should be a positive integer." assert ( decoder_start_token_id is not None or self.config.is_encoder_decoder is False ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model" @@ -674,11 +673,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): elif attention_mask is None: attention_mask = tf.ones_like(input_ids) - if pad_token_id is None and eos_token_ids is not None: + if pad_token_id is None and eos_token_id is not None: logger.warning( - "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0]) + "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id) ) - pad_token_id = eos_token_ids[0] + pad_token_id = eos_token_id # current position and vocab size cur_len = shape_list(input_ids)[1] @@ -742,7 +741,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): no_repeat_ngram_size=no_repeat_ngram_size, bos_token_id=bos_token_id, pad_token_id=pad_token_id, - eos_token_ids=eos_token_ids, + eos_token_id=eos_token_id, decoder_start_token_id=decoder_start_token_id, batch_size=effective_batch_size, num_return_sequences=num_return_sequences, @@ -766,7 +765,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): no_repeat_ngram_size=no_repeat_ngram_size, bos_token_id=bos_token_id, 
pad_token_id=pad_token_id, - eos_token_ids=eos_token_ids, + eos_token_id=eos_token_id, decoder_start_token_id=decoder_start_token_id, batch_size=effective_batch_size, vocab_size=vocab_size, @@ -790,7 +789,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): no_repeat_ngram_size, bos_token_id, pad_token_id, - eos_token_ids, + eos_token_id, decoder_start_token_id, batch_size, vocab_size, @@ -839,10 +838,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ) # set eos token prob to zero if min_length is not reached - if eos_token_ids is not None and cur_len < min_length: - # create eos_token_ids boolean mask + if eos_token_id is not None and cur_len < min_length: + # create eos_token_id boolean mask is_token_logit_eos_token = tf.convert_to_tensor( - [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool + [True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool ) eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size]) @@ -865,28 +864,27 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32) # update generations and finished sentences - if eos_token_ids is not None: - # pad finished sentences if eos_token_ids exist + if eos_token_id is not None: + # pad finished sentences if eos_token_id exist tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) else: tokens_to_add = next_token input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1) - if eos_token_ids is not None: - for eos_token_id in eos_token_ids: - eos_in_sents = tokens_to_add == eos_token_id - # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length - is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply( - unfinished_sents, tf.cast(eos_in_sents, tf.int32) - ) - sent_lengths = ( - sent_lengths * (1 - 
is_sents_unfinished_and_token_to_add_is_eos) - + cur_len * is_sents_unfinished_and_token_to_add_is_eos - ) + if eos_token_id is not None: + eos_in_sents = tokens_to_add == eos_token_id + # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length + is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply( + unfinished_sents, tf.cast(eos_in_sents, tf.int32) + ) + sent_lengths = ( + sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos) + + cur_len * is_sents_unfinished_and_token_to_add_is_eos + ) - # unfinished_sents is set to zero if eos in sentence - unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos + # unfinished_sents is set to zero if eos in sentence + unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos # stop when there is a in each sentence, or if we exceed the maximul length if tf.math.reduce_max(unfinished_sents) == 0: @@ -937,8 +935,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): no_repeat_ngram_size, bos_token_id, pad_token_id, - eos_token_ids, decoder_start_token_id, + eos_token_id, batch_size, num_return_sequences, length_penalty, @@ -996,10 +994,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) # set eos token prob to zero if min_length is not reached - if eos_token_ids is not None and cur_len < min_length: - # create eos_token_ids boolean mask + if eos_token_id is not None and cur_len < min_length: + # create eos_token_id boolean mask is_token_logit_eos_token = tf.convert_to_tensor( - [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool + [True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool ) eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size]) @@ -1072,7 +1070,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): 
len(generated_hyps[batch_idx]) >= num_beams ), "Batch can only be done if at least {} beams have been generated".format(num_beams) assert ( - eos_token_ids is not None and pad_token_id is not None + eos_token_id is not None and pad_token_id is not None ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch continue @@ -1091,7 +1089,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): effective_beam_id = batch_idx * num_beams + beam_id # add to generated hypotheses if end of sentence or last iteration - if eos_token_ids is not None and token_id.numpy() in eos_token_ids: + if eos_token_id is not None and token_id.numpy() is eos_token_id: # if beam_token does not belong to top num_beams tokens, it should not be added is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams if is_beam_token_worse_than_top_num_beams: @@ -1148,8 +1146,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): if done[batch_idx]: continue # test that beam scores match previously calculated scores if not eos and batch_idx not done - if eos_token_ids is not None and all( - (token_id % vocab_size).numpy().item() not in eos_token_ids for token_id in next_tokens[batch_idx] + if eos_token_id is not None and all( + (token_id % vocab_size).numpy().item() is not eos_token_id for token_id in next_tokens[batch_idx] ): assert tf.reduce_all( next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx] @@ -1199,7 +1197,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): if sent_lengths[i] < max_length: decoded_hypo = tf.where( tf.range(max_length) == sent_lengths[i], - eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32), + eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32), decoded_hypo, ) decoded_list.append(decoded_hypo) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py 
index 97bee18091..808c160094 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -665,7 +665,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): repetition_penalty=None, bos_token_id=None, pad_token_id=None, - eos_token_ids=None, + eos_token_id=None, length_penalty=None, no_repeat_ngram_size=None, num_return_sequences=None, @@ -713,6 +713,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): repetition_penalty: (`optional`) float The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0. + pad_token_id: (`optional`) int + Padding token. Default to specific model pad_token_id or None if it does not exist. + bos_token_id: (`optional`) int BOS token. Defaults to bos_token_id as defined in the models config. @@ -798,7 +801,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty no_repeat_ngram_size = ( no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size @@ -812,8 +815,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): batch_size = input_ids.shape[0] # overriden by the input batch_size else: batch_size = 1 - if isinstance(eos_token_ids, int): - eos_token_ids = [eos_token_ids] assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." 
@@ -830,12 +831,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): assert pad_token_id is None or ( isinstance(pad_token_id, int) and (pad_token_id >= 0) ), "`pad_token_id` should be a positive integer." - assert (eos_token_ids is None) or ( - isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids) - ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers." assert ( decoder_start_token_id is not None or self.config.is_encoder_decoder is False ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model" + assert (eos_token_id is None) or ( + isinstance(eos_token_id, int) and (eos_token_id >= 0) + ), "`eos_token_id` should be a positive integer." assert length_penalty > 0, "`length_penalty` should be strictly positive." assert ( isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0 @@ -876,13 +877,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): elif attention_mask is None: attention_mask = input_ids.new_ones(input_ids.shape) - # set pad_token_id to eos_token_ids if not set. Important that this is done after + # set pad_token_id to eos_token_id if not set. 
Important that this is done after # attention_mask is created - if pad_token_id is None and eos_token_ids is not None: + if pad_token_id is None and eos_token_id is not None: logger.warning( - "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0]) + "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id) ) - pad_token_id = eos_token_ids[0] + pad_token_id = eos_token_id # current position and vocab size vocab_size = self.config.vocab_size @@ -947,8 +948,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): no_repeat_ngram_size=no_repeat_ngram_size, bos_token_id=bos_token_id, pad_token_id=pad_token_id, - eos_token_ids=eos_token_ids, decoder_start_token_id=decoder_start_token_id, + eos_token_id=eos_token_id, batch_size=effective_batch_size, num_return_sequences=num_return_sequences, length_penalty=length_penalty, @@ -971,8 +972,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): no_repeat_ngram_size=no_repeat_ngram_size, bos_token_id=bos_token_id, pad_token_id=pad_token_id, - eos_token_ids=eos_token_ids, decoder_start_token_id=decoder_start_token_id, + eos_token_id=eos_token_id, batch_size=effective_batch_size, encoder_outputs=encoder_outputs, attention_mask=attention_mask, @@ -994,7 +995,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): no_repeat_ngram_size, bos_token_id, pad_token_id, - eos_token_ids, + eos_token_id, decoder_start_token_id, batch_size, encoder_outputs, @@ -1031,9 +1032,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf") # set eos token prob to zero if min_length is not reached - if eos_token_ids is not None and cur_len < min_length: - for eos_token_id in eos_token_ids: - next_token_logits[:, eos_token_id] = -float("inf") + if eos_token_id is not None and cur_len < min_length: + next_token_logits[:, eos_token_id] = -float("inf") if do_sample: # Temperature (higher temperature 
=> more likely to sample low probability tokens) @@ -1049,22 +1049,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): next_token = torch.argmax(next_token_logits, dim=-1) # update generations and finished sentences - if eos_token_ids is not None: - # pad finished sentences if eos_token_ids exist + if eos_token_id is not None: + # pad finished sentences if eos_token_id exist tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) else: tokens_to_add = next_token input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) - if eos_token_ids is not None: - for eos_token_id in eos_token_ids: - eos_in_sents = tokens_to_add == eos_token_id - # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length - is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool() - sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1) - # unfinished_sents is set to zero if eos in sentence - unfinished_sents.mul_((~eos_in_sents).long()) + if eos_token_id is not None: + eos_in_sents = tokens_to_add == eos_token_id + # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length + is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool() + sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1) + # unfinished_sents is set to zero if eos in sentence + unfinished_sents.mul_((~eos_in_sents).long()) # stop when there is a in each sentence, or if we exceed the maximul length if unfinished_sents.max() == 0: @@ -1106,7 +1105,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): no_repeat_ngram_size, bos_token_id, pad_token_id, - eos_token_ids, + eos_token_id, decoder_start_token_id, batch_size, num_return_sequences, @@ -1163,9 +1162,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): scores = self.prepare_scores_for_generation(scores, 
cur_len=cur_len, max_length=max_length) # set eos token prob to zero if min_length is not reached - if eos_token_ids is not None and cur_len < min_length: - for eos_token_id in eos_token_ids: - scores[:, eos_token_id] = -float("inf") + if eos_token_id is not None and cur_len < min_length: + scores[:, eos_token_id] = -float("inf") if no_repeat_ngram_size > 0: # calculate a list of banned tokens to prevent repetitively generating the same ngrams @@ -1225,7 +1223,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): len(generated_hyps[batch_idx]) >= num_beams ), "Batch can only be done if at least {} beams have been generated".format(num_beams) assert ( - eos_token_ids is not None and pad_token_id is not None + eos_token_id is not None and pad_token_id is not None ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch continue @@ -1244,7 +1242,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): effective_beam_id = batch_idx * num_beams + beam_id # add to generated hypotheses if end of sentence - if (eos_token_ids is not None) and (token_id.item() in eos_token_ids): + if (eos_token_id is not None) and (token_id.item() is eos_token_id): # if beam_token does not belong to top num_beams tokens, it should not be added is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams if is_beam_token_worse_than_top_num_beams: @@ -1303,8 +1301,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): continue # test that beam scores match previously calculated scores if not eos and batch_idx not done - if eos_token_ids is not None and all( - (token_id % vocab_size).item() not in eos_token_ids for token_id in next_tokens[batch_idx] + if eos_token_id is not None and all( + (token_id % vocab_size).item() is not eos_token_id for token_id in next_tokens[batch_idx] ): assert torch.all( next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, 
num_beams)[batch_idx] @@ -1346,7 +1344,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): for i, hypo in enumerate(best): decoded[i, : sent_lengths[i]] = hypo if sent_lengths[i] < max_length: - decoded[i, sent_lengths[i]] = eos_token_ids[0] + decoded[i, sent_lengths[i]] = eos_token_id else: # none of the hypotheses have an eos_token assert (len(hypo) == max_length for hypo in best) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 4e26ee68c1..d064f0f780 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -61,7 +61,7 @@ class ModelTester: self.hidden_dropout_prob = 0.1 self.attention_probs_dropout_prob = 0.1 self.max_position_embeddings = 20 - self.eos_token_ids = [2] + self.eos_token_id = 2 self.pad_token_id = 1 self.bos_token_id = 0 torch.manual_seed(0) @@ -82,7 +82,7 @@ class ModelTester: dropout=self.hidden_dropout_prob, attention_dropout=self.attention_probs_dropout_prob, max_position_embeddings=self.max_position_embeddings, - eos_token_ids=self.eos_token_ids, + eos_token_id=self.eos_token_id, bos_token_id=self.bos_token_id, pad_token_id=self.pad_token_id, ) @@ -214,7 +214,7 @@ class BartHeadTests(unittest.TestCase): decoder_ffn_dim=32, max_position_embeddings=48, output_past=output_past, - eos_token_ids=[2], + eos_token_id=2, pad_token_id=1, bos_token_id=0, ) @@ -274,7 +274,7 @@ class BartHeadTests(unittest.TestCase): decoder_ffn_dim=32, max_position_embeddings=48, output_past=True, - eos_token_ids=[2], + eos_token_id=2, pad_token_id=1, bos_token_id=0, ) @@ -483,7 +483,7 @@ class BartModelIntegrationTests(unittest.TestCase): no_repeat_ngram_size=3, do_sample=False, early_stopping=True, - decoder_start_token_id=hf.config.eos_token_ids[0], + decoder_start_token_id=hf.config.eos_token_id, ) decoded = [ diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index bd5fe6f32c..74c2d9011f 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -132,7 +132,7 @@ class 
GPT2ModelTest(ModelTesterMixin, unittest.TestCase): # type_vocab_size=self.type_vocab_size, # initializer_range=self.initializer_range bos_token_id=self.bos_token_id, - eos_token_ids=self.eos_token_id, + eos_token_id=self.eos_token_id, ) head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py index 6399a60f57..767fa3a2d0 100644 --- a/tests/test_modeling_tf_gpt2.py +++ b/tests/test_modeling_tf_gpt2.py @@ -130,7 +130,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): # type_vocab_size=self.type_vocab_size, # initializer_range=self.initializer_range bos_token_id=self.bos_token_id, - eos_token_ids=self.eos_token_id, + eos_token_id=self.eos_token_id, ) head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) diff --git a/tests/test_modeling_tf_transfo_xl.py b/tests/test_modeling_tf_transfo_xl.py index 819cfc7a33..b432f49e3c 100644 --- a/tests/test_modeling_tf_transfo_xl.py +++ b/tests/test_modeling_tf_transfo_xl.py @@ -107,7 +107,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase): d_inner=self.d_inner, div_val=self.div_val, n_layer=self.num_hidden_layers, - eos_token_ids=self.eos_token_id, + eos_token_id=self.eos_token_id, ) return (config, input_ids_1, input_ids_2, lm_labels) diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index f0152dfd40..18212c8828 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -103,7 +103,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase): d_inner=self.d_inner, div_val=self.div_val, n_layer=self.num_hidden_layers, - eos_token_ids=self.eos_token_id, + eos_token_id=self.eos_token_id, ) return (config, input_ids_1, input_ids_2, lm_labels)