From c130e67dce56a092604949a8df6384a17f762189 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 10 Feb 2021 22:39:09 +0530 Subject: [PATCH] remove adjust_logits_during_generation method (#10087) * add forced logits processors * delete adjust_logits method * add forced_eos_token_id argument in config * add tests for forced logits processors * update gen utils tests * add forced option to tf generate * remove adjust_logits method from tf models * update adjust_logits for marian * delete _force_token_id_to_be_generated method * style * import warnings * pass max_length to _get_logits_processor * set forced_eos_token_id to None * set forced attributes in conf utils * typo * fix rag generate * add forced_eos_token_id in rag config * remove force_bos_token_to_be_generated from BartConfig * remove _force_token_ids_generation from FSMT * nit * fix negative constant * apply suggestions from code review --- src/transformers/configuration_utils.py | 7 + src/transformers/generation_logits_process.py | 46 ++++++ src/transformers/generation_tf_utils.py | 37 ++++- src/transformers/generation_utils.py | 35 ++++- .../models/bart/configuration_bart.py | 19 ++- src/transformers/models/bart/modeling_bart.py | 12 -- .../models/bart/modeling_tf_bart.py | 10 -- .../blenderbot/configuration_blenderbot.py | 5 + .../models/blenderbot/modeling_blenderbot.py | 10 -- .../blenderbot/modeling_tf_blenderbot.py | 7 - .../configuration_blenderbot_small.py | 5 + .../modeling_blenderbot_small.py | 10 -- .../modeling_tf_blenderbot_small.py | 7 - .../models/fsmt/configuration_fsmt.py | 5 + src/transformers/models/fsmt/modeling_fsmt.py | 17 --- .../models/marian/configuration_marian.py | 5 + .../models/marian/modeling_marian.py | 7 - .../models/marian/modeling_tf_marian.py | 15 +- .../models/mbart/configuration_mbart.py | 5 + .../models/mbart/modeling_mbart.py | 10 -- .../models/mbart/modeling_tf_mbart.py | 7 - .../models/pegasus/configuration_pegasus.py | 5 + .../models/pegasus/modeling_pegasus.py 
| 10 -- .../models/pegasus/modeling_tf_pegasus.py | 7 - .../models/rag/configuration_rag.py | 8 ++ src/transformers/models/rag/modeling_rag.py | 14 +- tests/test_generation_logits_process.py | 43 ++++++ tests/test_generation_utils.py | 132 +++++++++++++----- tests/test_pipelines_summarization.py | 1 + 29 files changed, 335 insertions(+), 166 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 6a982d8b2d..0fba6fc32a 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -131,6 +131,11 @@ class PretrainedConfig(object): logits when used for generation - **return_dict_in_generate** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should return a :class:`~transformers.file_utils.ModelOutput` instead of a :obj:`torch.LongTensor` + - **forced_bos_token_id** (:obj:`int`, `optional`) -- The id of the token to force as the first generated token + after the :obj:`decoder_start_token_id`. Useful for multilingual models like :doc:`mBART + <../model_doc/mbart>` where the first generated token needs to be the target language token. + - **forced_eos_token_id** (:obj:`int`, `optional`) -- The id of the token to force as the last generated token + when :obj:`max_length` is reached. 
Parameters for fine-tuning tasks @@ -214,6 +219,8 @@ class PretrainedConfig(object): self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0) self.output_scores = kwargs.pop("output_scores", False) self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False) + self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None) + self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None) # Fine-tuning task arguments self.architectures = kwargs.pop("architectures", None) diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index 85d2c9df36..8d42aba12a 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -520,3 +520,49 @@ class HammingDiversityLogitsProcessor(LogitsProcessor): scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency return scores + + +class ForcedBOSTokenLogitsProcessor(LogitsProcessor): + r""" + :class:`~transformers.LogitsProcessor` that enforces the specified token as the first generated token. + + Args: + bos_token_id (:obj:`int`): + The id of the token to force as the first generated token. + """ + + def __init__(self, bos_token_id: int): + self.bos_token_id = bos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + cur_len = input_ids.shape[-1] + if cur_len == 1: + num_tokens = scores.shape[1] + scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf") + scores[:, self.bos_token_id] = 0 + return scores + + +class ForcedEOSTokenLogitsProcessor(LogitsProcessor): + r""" + :class:`~transformers.LogitsProcessor` that enforces the specified token as the last generated token when + :obj:`max_length` is reached. + + Args: + max_length (:obj:`int`): + The maximum length of the sequence to be generated. 
+ eos_token_id (:obj:`int`): + The id of the token to force as the last generated token when :obj:`max_length` is reached. + """ + + def __init__(self, max_length: int, eos_token_id: int): + self.max_length = max_length + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + cur_len = input_ids.shape[-1] + if cur_len == self.max_length - 1: + num_tokens = scores.shape[1] + scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf") + scores[:, self.eos_token_id] = 0 + return scores diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index e356157ba4..1c723365ec 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -67,6 +67,8 @@ class TFGenerationMixin: attention_mask=None, decoder_start_token_id=None, use_cache=None, + forced_bos_token_id=None, + forced_eos_token_id=None, ): r""" Generates sequences for models with a language modeling head. The method currently supports greedy decoding, @@ -137,6 +139,12 @@ class TFGenerationMixin: use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. + forced_bos_token_id (:obj:`int`, `optional`): + The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`. + Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token + needs to be the target language token. + forced_eos_token_id (:obj:`int`, `optional`): + The id of the token to force as the last generated token when :obj:`max_length` is reached. model_specific_kwargs: Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. 
@@ -214,6 +222,12 @@ class TFGenerationMixin: decoder_start_token_id = ( decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id ) + forced_bos_token_id = ( + forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id + ) + forced_eos_token_id = ( + forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id + ) if input_ids is not None: batch_size = shape_list(input_ids)[0] # overridden by the input batch_size @@ -380,6 +394,8 @@ class TFGenerationMixin: encoder_outputs=encoder_outputs, attention_mask=attention_mask, use_cache=use_cache, + forced_bos_token_id=forced_bos_token_id, + forced_eos_token_id=forced_eos_token_id, ) else: output = self._generate_no_beam_search( @@ -591,6 +607,8 @@ class TFGenerationMixin: encoder_outputs, attention_mask, use_cache, + forced_bos_token_id, + forced_eos_token_id, ): """Generate sequences for each example with beam search.""" @@ -641,7 +659,11 @@ class TFGenerationMixin: if self.config.is_encoder_decoder and do_sample is False: next_token_logits = self.adjust_logits_during_generation( - next_token_logits, cur_len=cur_len, max_length=max_length + next_token_logits, + cur_len=cur_len, + max_length=max_length, + forced_bos_token_id=forced_bos_token_id, + forced_eos_token_id=forced_eos_token_id, ) # calculate log softmax score scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) @@ -893,12 +915,21 @@ class TFGenerationMixin: def _reorder_cache(past, beam_idx): return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past) - def adjust_logits_during_generation(self, logits, **kwargs): + def adjust_logits_during_generation( + self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs + ): """ Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in the generate method. 
""" - return logits + if cur_len == 1 and forced_bos_token_id is not None: + vocab_range = tf.constant(range(self.config.vocab_size)) + return tf.where(vocab_range != forced_bos_token_id, -1e8, logits) + elif cur_len == max_length - 1 and forced_eos_token_id is not None: + vocab_range = tf.constant(range(self.config.vocab_size)) + return tf.where(vocab_range != forced_eos_token_id, -1e8, logits) + else: + return logits def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty): diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index fcf01ab401..0ce5df28b9 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -24,6 +24,8 @@ from .file_utils import ModelOutput from .generation_beam_search import BeamScorer, BeamSearchScorer from .generation_logits_process import ( EncoderNoRepeatNGramLogitsProcessor, + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, HammingDiversityLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, @@ -542,7 +544,10 @@ class GenerationMixin: encoder_input_ids: torch.LongTensor, bad_words_ids: List[List[int]], min_length: int, + max_length: int, eos_token_id: int, + forced_bos_token_id: int, + forced_eos_token_id: int, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int, num_beam_groups: int, @@ -567,6 +572,12 @@ class GenerationMixin: min_length = min_length if min_length is not None else self.config.min_length eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty + forced_bos_token_id = ( + forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id + ) + forced_eos_token_id = ( + forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id + ) # instantiate processors list processors = 
LogitsProcessorList() @@ -595,6 +606,10 @@ class GenerationMixin: processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) if prefix_allowed_tokens_fn is not None: processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams)) + if forced_bos_token_id is not None: + processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) + if forced_eos_token_id is not None: + processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) return processors @torch.no_grad() @@ -627,6 +642,8 @@ class GenerationMixin: output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, + forced_bos_token_id: Optional[int] = None, + forced_eos_token_id: Optional[int] = None, **model_kwargs, ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: r""" @@ -720,6 +737,12 @@ class GenerationMixin: Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details. return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`): Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + forced_bos_token_id (:obj:`int`, `optional`): + The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`. + Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token + needs to be the target language token. + forced_eos_token_id (:obj:`int`, `optional`): + The id of the token to force as the last generated token when :obj:`max_length` is reached. model_kwargs: Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. 
If the @@ -888,7 +911,10 @@ class GenerationMixin: encoder_input_ids=encoder_input_ids, bad_words_ids=bad_words_ids, min_length=min_length, + max_length=max_length, eos_token_id=eos_token_id, + forced_bos_token_id=forced_bos_token_id, + forced_eos_token_id=forced_eos_token_id, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, num_beams=num_beams, num_beam_groups=num_beam_groups, @@ -1611,7 +1637,8 @@ class GenerationMixin: ) next_token_logits = outputs.logits[:, -1, :] - # adjust tokens for Bart, *e.g.* + # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` + # cannot be generated both before and after the `F.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation( next_token_logits, cur_len=cur_len, max_length=max_length ) @@ -1866,7 +1893,8 @@ class GenerationMixin: ) next_token_logits = outputs.logits[:, -1, :] - # adjust token scores (a no-op by default) + # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` + # cannot be generated both before and after the `F.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation( next_token_logits, cur_len=cur_len, max_length=max_length ) @@ -2150,7 +2178,8 @@ class GenerationMixin: # select outputs of beams of current group only next_token_logits = outputs.logits[batch_group_indices, -1, :] - # adjust tokens for Bart, *e.g.* + # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` + # cannot be generated both before and after the `F.log_softmax` operation. 
next_token_logits = self.adjust_logits_during_generation( next_token_logits, cur_len=cur_len, max_length=max_length ) diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py index 1bbdfabcd3..0ea94f76b8 100644 --- a/src/transformers/models/bart/configuration_bart.py +++ b/src/transformers/models/bart/configuration_bart.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ BART model configuration """ +import warnings from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -72,9 +73,6 @@ class BartConfig(PretrainedConfig): just in case (e.g., 512 or 1024 or 2048). init_std (:obj:`float`, `optional`, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``), only - :obj:`True` for `bart-large-cnn`. encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the encoder. See the `LayerDrop paper `__ for more details. @@ -89,6 +87,9 @@ class BartConfig(PretrainedConfig): Whether or not the model should return the last key/values attentions (not used by all models). num_labels: (:obj:`int`, `optional`, defaults to 3): The number of labels to use in :class:`~transformers.BartForSequenceClassification`. + forced_eos_token_id (:obj:`int`, `optional`, defaults to 2): + The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to + :obj:`eos_token_id`. 
Example:: @@ -127,7 +128,6 @@ class BartConfig(PretrainedConfig): classifier_dropout=0.0, scale_embedding=False, gradient_checkpointing=False, - force_bos_token_to_be_generated=False, use_cache=True, num_labels=3, pad_token_id=1, @@ -135,6 +135,7 @@ class BartConfig(PretrainedConfig): eos_token_id=2, is_encoder_decoder=True, decoder_start_token_id=2, + forced_eos_token_id=2, **kwargs ): super().__init__( @@ -144,6 +145,7 @@ class BartConfig(PretrainedConfig): eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, **kwargs, ) @@ -168,7 +170,14 @@ class BartConfig(PretrainedConfig): self.num_hidden_layers = encoder_layers self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True - self.force_bos_token_to_be_generated = force_bos_token_to_be_generated # only relevant for CNN + + # ensure backward compatibility for BART CNN models + if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): + self.forced_bos_token_id = self.bos_token_id + warnings.warn( + f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. " + "The config can simply be saved and uploaded again to be fixed." 
+ ) @property def num_attention_heads(self) -> int: diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index a595e25d72..e538fe456a 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1344,18 +1344,6 @@ class BartForConditionalGeneration(BartPretrainedModel): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - def adjust_logits_during_generation(self, logits, cur_len, max_length): - if cur_len == 1 and self.config.force_bos_token_to_be_generated: - self._force_token_id_to_be_generated(logits, self.config.bos_token_id) - elif cur_len == max_length - 1 and self.config.eos_token_id is not None: - self._force_token_id_to_be_generated(logits, self.config.eos_token_id) - return logits - - @staticmethod - def _force_token_id_to_be_generated(scores, token_id) -> None: - """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" - scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf") - @staticmethod def _reorder_cache(past, beam_idx): reordered_past = () diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index e1dbf0c23b..6d00beafc0 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -1444,13 +1444,3 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode + layer_past_key_values[2:], ) return (past[0], reordered_past) - - def adjust_logits_during_generation(self, logits, cur_len, max_length): - if cur_len == 1 and self.config.force_bos_token_to_be_generated: - vocab_range = tf.constant(range(self.config.vocab_size)) - return tf.where(vocab_range != self.config.bos_token_id, LARGE_NEGATIVE, logits) - elif cur_len 
== max_length - 1: - vocab_range = tf.constant(range(self.config.vocab_size)) - return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) - else: - return logits diff --git a/src/transformers/models/blenderbot/configuration_blenderbot.py b/src/transformers/models/blenderbot/configuration_blenderbot.py index 4de6a9d12a..1712d7cbf6 100644 --- a/src/transformers/models/blenderbot/configuration_blenderbot.py +++ b/src/transformers/models/blenderbot/configuration_blenderbot.py @@ -84,6 +84,9 @@ class BlenderbotConfig(PretrainedConfig): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models) + forced_eos_token_id (:obj:`int`, `optional`, defaults to 2): + The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to + :obj:`eos_token_id`. Example:: @@ -129,6 +132,7 @@ class BlenderbotConfig(PretrainedConfig): bos_token_id=1, eos_token_id=2, encoder_no_repeat_ngram_size=3, + forced_eos_token_id=2, **kwargs ): super().__init__( @@ -138,6 +142,7 @@ class BlenderbotConfig(PretrainedConfig): is_encoder_decoder=is_encoder_decoder, decoder_start_token_id=decoder_start_token_id, encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, + forced_eos_token_id=forced_eos_token_id, **kwargs, ) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index cfe826d1af..18f6d308f1 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1335,16 +1335,6 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } - def adjust_logits_during_generation(self, logits, cur_len, max_length): - if cur_len == max_length - 1 and 
self.config.eos_token_id is not None: - self._force_token_id_to_be_generated(logits, self.config.eos_token_id) - return logits - - @staticmethod - def _force_token_id_to_be_generated(scores, token_id) -> None: - """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" - scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf") - @staticmethod def _reorder_cache(past, beam_idx): reordered_past = () diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index c5463d1fc3..c78392a4f3 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -1477,10 +1477,3 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal + layer_past_key_values[2:], ) return (past[0], reordered_past) - - def adjust_logits_during_generation(self, logits, cur_len, max_length): - if cur_len == max_length - 1: - vocab_range = tf.constant(range(self.config.vocab_size)) - return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) - else: - return logits diff --git a/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py b/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py index b7bde44dda..9961980124 100644 --- a/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py @@ -84,6 +84,9 @@ class BlenderbotSmallConfig(PretrainedConfig): Scale embeddings by diving by sqrt(d_model). 
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models) + forced_eos_token_id (:obj:`int`, `optional`, defaults to 2): + The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to + :obj:`eos_token_id`. Example:: @@ -128,6 +131,7 @@ class BlenderbotSmallConfig(PretrainedConfig): pad_token_id=0, bos_token_id=1, eos_token_id=2, + forced_eos_token_id=2, **kwargs ): super().__init__( @@ -136,6 +140,7 @@ class BlenderbotSmallConfig(PretrainedConfig): eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, **kwargs, ) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index ffc83ed187..3c520c941a 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1310,16 +1310,6 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): "use_cache": use_cache, # change this to avoid caching (presumably for debugging) } - def adjust_logits_during_generation(self, logits, cur_len, max_length): - if cur_len == max_length - 1 and self.config.eos_token_id is not None: - self._force_token_id_to_be_generated(logits, self.config.eos_token_id) - return logits - - @staticmethod - def _force_token_id_to_be_generated(scores, token_id) -> None: - """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" - scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf") - @staticmethod def _reorder_cache(past, beam_idx): reordered_past = () diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py 
b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index b6fb43081f..8500d177fa 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -1452,10 +1452,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel + layer_past_key_values[2:], ) return (past[0], reordered_past) - - def adjust_logits_during_generation(self, logits, cur_len, max_length): - if cur_len == max_length - 1: - vocab_range = tf.constant(range(self.config.vocab_size)) - return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) - else: - return logits diff --git a/src/transformers/models/fsmt/configuration_fsmt.py b/src/transformers/models/fsmt/configuration_fsmt.py index 1c1b1f6548..d7a79298c7 100644 --- a/src/transformers/models/fsmt/configuration_fsmt.py +++ b/src/transformers/models/fsmt/configuration_fsmt.py @@ -111,6 +111,9 @@ class FSMTConfig(PretrainedConfig): search when at least ``num_beams`` sentences are finished per batch or not. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). + forced_eos_token_id (:obj:`int`, `optional`, defaults to 2): + The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to + :obj:`eos_token_id`. 
Examples:: @@ -155,6 +158,7 @@ class FSMTConfig(PretrainedConfig): pad_token_id=1, bos_token_id=0, eos_token_id=2, + forced_eos_token_id=2, **common_kwargs ): if "hidden_size" in common_kwargs: @@ -166,6 +170,7 @@ class FSMTConfig(PretrainedConfig): decoder_start_token_id=decoder_start_token_id, is_encoder_decoder=is_encoder_decoder, tie_word_embeddings=tie_word_embeddings, + forced_eos_token_id=forced_eos_token_id, **common_kwargs, ) self.langs = langs diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index c2fcf9c8eb..5ad1a0ca7b 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -1210,23 +1210,6 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) - def adjust_logits_during_generation(self, logits, cur_len, max_length): - if cur_len == max_length - 1 and self.config.eos_token_id is not None: - self._force_token_ids_generation(logits, self.config.eos_token_id) - return logits - - def _force_token_ids_generation(self, scores, token_ids) -> None: - """force one of token_ids to be generated by setting prob of all other tokens to 0""" - if isinstance(token_ids, int): - token_ids = [token_ids] - all_but_token_ids_mask = torch.tensor( - [x for x in range(self.config.tgt_vocab_size) if x not in token_ids], - dtype=torch.long, - device=next(self.parameters()).device, - ) - assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]" - scores[:, all_but_token_ids_mask] = -float("inf") - @staticmethod def _reorder_cache(past, beam_idx): reordered_past = [] diff --git a/src/transformers/models/marian/configuration_marian.py b/src/transformers/models/marian/configuration_marian.py index 7e57b6e975..15893eef30 100644 --- a/src/transformers/models/marian/configuration_marian.py +++ 
b/src/transformers/models/marian/configuration_marian.py @@ -84,6 +84,9 @@ class MarianConfig(PretrainedConfig): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models) + forced_eos_token_id (:obj:`int`, `optional`, defaults to 0): + The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to + :obj:`eos_token_id`. Examples:: @@ -127,6 +130,7 @@ class MarianConfig(PretrainedConfig): gradient_checkpointing=False, pad_token_id=58100, eos_token_id=0, + forced_eos_token_id=0, **kwargs ): super().__init__( @@ -134,6 +138,7 @@ class MarianConfig(PretrainedConfig): eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, **kwargs, ) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 24011e31f5..5dbd782090 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1325,15 +1325,8 @@ class MarianMTModel(MarianPreTrainedModel): def adjust_logits_during_generation(self, logits, cur_len, max_length): logits[:, self.config.pad_token_id] = float("-inf") # never predict pad token. 
- if cur_len == max_length - 1 and self.config.eos_token_id is not None: - self._force_token_id_to_be_generated(logits, self.config.eos_token_id) return logits - @staticmethod - def _force_token_id_to_be_generated(scores, token_id) -> None: - """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" - scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf") - @staticmethod def _reorder_cache(past, beam_idx): reordered_past = () diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index 674bf1b52e..27004bf955 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -1470,10 +1470,17 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): ) return (past[0], reordered_past) - def adjust_logits_during_generation(self, logits, cur_len, max_length): + def adjust_logits_during_generation( + self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs + ): """Never predict pad_token_id. 
Predict when max_length is reached.""" vocab_range = tf.constant(range(self.config.vocab_size)) logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits) - if cur_len == max_length - 1: - logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) - return logits + if cur_len == 1 and forced_bos_token_id is not None: + vocab_range = tf.constant(range(self.config.vocab_size)) + return tf.where(vocab_range != forced_bos_token_id, LARGE_NEGATIVE, logits) + elif cur_len == max_length - 1 and forced_eos_token_id is not None: + vocab_range = tf.constant(range(self.config.vocab_size)) + return tf.where(vocab_range != forced_eos_token_id, LARGE_NEGATIVE, logits) + else: + return logits diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py index b81a713fd9..d8f8364850 100644 --- a/src/transformers/models/mbart/configuration_mbart.py +++ b/src/transformers/models/mbart/configuration_mbart.py @@ -84,6 +84,9 @@ class MBartConfig(PretrainedConfig): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models) + forced_eos_token_id (:obj:`int`, `optional`, defaults to 2): + The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to + :obj:`eos_token_id`. 
Example:: @@ -127,6 +130,7 @@ class MBartConfig(PretrainedConfig): pad_token_id=1, bos_token_id=0, eos_token_id=2, + forced_eos_token_id=2, **kwargs ): super().__init__( @@ -134,6 +138,7 @@ class MBartConfig(PretrainedConfig): bos_token_id=bos_token_id, eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, + forced_eos_token_id=forced_eos_token_id, **kwargs, ) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 2aef23d1c0..cde5f5974a 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1344,16 +1344,6 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id) - def adjust_logits_during_generation(self, logits, cur_len, max_length): - if cur_len == max_length - 1 and self.config.eos_token_id is not None: - self._force_token_id_to_be_generated(logits, self.config.eos_token_id) - return logits - - @staticmethod - def _force_token_id_to_be_generated(scores, token_id) -> None: - """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" - scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf") - @staticmethod def _reorder_cache(past, beam_idx): reordered_past = () diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index 94d4e814f5..6c7726b596 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -1468,10 +1468,3 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo + layer_past_key_values[2:], ) return (past[0], reordered_past) - - def adjust_logits_during_generation(self, logits, cur_len, max_length): - if cur_len == max_length - 1: - vocab_range = 
tf.constant(range(self.config.vocab_size)) - return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) - else: - return logits diff --git a/src/transformers/models/pegasus/configuration_pegasus.py b/src/transformers/models/pegasus/configuration_pegasus.py index 0ed78b25fa..424458590c 100644 --- a/src/transformers/models/pegasus/configuration_pegasus.py +++ b/src/transformers/models/pegasus/configuration_pegasus.py @@ -84,6 +84,9 @@ class PegasusConfig(PretrainedConfig): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models) + forced_eos_token_id (:obj:`int`, `optional`, defaults to 1): + The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to + :obj:`eos_token_id`. Example:: @@ -127,6 +130,7 @@ class PegasusConfig(PretrainedConfig): gradient_checkpointing=False, pad_token_id=0, eos_token_id=1, + forced_eos_token_id=1, **kwargs ): super().__init__( @@ -134,6 +138,7 @@ class PegasusConfig(PretrainedConfig): eos_token_id=eos_token_id, is_encoder_decoder=is_encoder_decoder, decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, **kwargs, ) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 36f7f13ca0..2350fa5027 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1327,16 +1327,6 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) - def adjust_logits_during_generation(self, logits, cur_len, max_length): - if cur_len == max_length - 1 and self.config.eos_token_id is not None: - 
self._force_token_id_to_be_generated(logits, self.config.eos_token_id) - return logits - - @staticmethod - def _force_token_id_to_be_generated(scores, token_id) -> None: - """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" - scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf") - @staticmethod def _reorder_cache(past, beam_idx): reordered_past = () diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index 396d2c71c0..a94addbb21 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -1483,10 +1483,3 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua + layer_past_key_values[2:], ) return (past[0], reordered_past) - - def adjust_logits_during_generation(self, logits, cur_len, max_length): - if cur_len == max_length - 1: - vocab_range = tf.constant(range(self.config.vocab_size)) - return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits) - else: - return logits diff --git a/src/transformers/models/rag/configuration_rag.py b/src/transformers/models/rag/configuration_rag.py index 5ca27c8afc..252d91660e 100644 --- a/src/transformers/models/rag/configuration_rag.py +++ b/src/transformers/models/rag/configuration_rag.py @@ -74,6 +74,9 @@ RAG_CONFIG_DOC = r""" :obj:`context_attention_mask` are returned. See returned tensors for more detail. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). + forced_eos_token_id (:obj:`int`, `optional`): + The id of the token to force as the last generated token when :obj:`max_length` is reached. Usually set to + :obj:`eos_token_id`. 
""" @@ -110,6 +113,7 @@ class RagConfig(PretrainedConfig): do_marginalize=False, output_retrieved=False, use_cache=True, + forced_eos_token_id=None, **kwargs ): super().__init__( @@ -117,6 +121,7 @@ class RagConfig(PretrainedConfig): pad_token_id=pad_token_id, eos_token_id=eos_token_id, decoder_start_token_id=decoder_start_token_id, + forced_eos_token_id=forced_eos_token_id, is_encoder_decoder=is_encoder_decoder, prefix=prefix, vocab_size=vocab_size, @@ -161,6 +166,9 @@ class RagConfig(PretrainedConfig): self.use_cache = use_cache + if self.forced_eos_token_id is None: + self.forced_eos_token_id = getattr(self.generator, "forced_eos_token_id", None) + @classmethod def from_question_encoder_generator_configs( cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index b421751d27..5f893e11cd 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1089,9 +1089,6 @@ class RagTokenForGeneration(RagPreTrainedModel): def set_retriever(self, retriever: RagRetriever): self.rag.retriever = retriever - def adjust_logits_during_generation(self, logits, cur_len, max_length): - return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length) - def prepare_inputs_for_generation( self, decoder_input_ids, @@ -1313,6 +1310,8 @@ class RagTokenForGeneration(RagPreTrainedModel): decoder_start_token_id=None, n_docs=None, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None, + forced_bos_token_id: Optional[int] = None, + forced_eos_token_id: Optional[int] = None, **model_kwargs ): """ @@ -1403,6 +1402,12 @@ class RagTokenForGeneration(RagPreTrainedModel): conditioned on the previously generated tokens :obj:`inputs_ids` and the batch ID :obj:`batch_id`. 
This argument is useful for constrained generation conditioned on the prefix, as described in `Autoregressive Entity Retrieval <https://arxiv.org/abs/2010.00904>`__. + forced_bos_token_id (:obj:`int`, `optional`): + The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`. + Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token + needs to be the target language token. + forced_eos_token_id (:obj:`int`, `optional`): + The id of the token to force as the last generated token when :obj:`max_length` is reached. Return: :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated @@ -1498,7 +1503,10 @@ class RagTokenForGeneration(RagPreTrainedModel): encoder_input_ids=context_input_ids, bad_words_ids=bad_words_ids, min_length=min_length, + max_length=max_length, eos_token_id=eos_token_id, + forced_bos_token_id=forced_bos_token_id, + forced_eos_token_id=forced_eos_token_id, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, num_beams=num_beams, num_beam_groups=num_beam_groups, diff --git a/tests/test_generation_logits_process.py b/tests/test_generation_logits_process.py index 315417df33..85a589b7c2 100644 --- a/tests/test_generation_logits_process.py +++ b/tests/test_generation_logits_process.py @@ -28,6 +28,8 @@ if is_torch_available(): from transformers.generation_logits_process import ( EncoderNoRepeatNGramLogitsProcessor, + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, HammingDiversityLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, @@ -393,3 +395,44 @@ class LogitsProcessorTest(unittest.TestCase): processed_scores[1], torch.tensor([0.2500, -0.7500, 0.2500, 0.2500], device=torch_device), atol=1e-3 ) ) + + def test_forced_bos_token_logits_processor(self): + vocab_size = 20 + batch_size = 4 + bos_token_id = 0 + + logits_processor = ForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id) + + # check that all scores are -inf except the
bos_token_id score + input_ids = ids_tensor((batch_size, 1), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores = logits_processor(input_ids, scores) + self.assertTrue(torch.isneginf(scores[:, bos_token_id + 1 :]).all()) + self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0]) # score for bos_token_id should be zero + + # check that bos_token_id is not forced if current length is greater than 1 + input_ids = ids_tensor((batch_size, 4), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores = logits_processor(input_ids, scores) + self.assertFalse(torch.isinf(scores).any()) + + def test_forced_eos_token_logits_processor(self): + vocab_size = 20 + batch_size = 4 + eos_token_id = 0 + max_length = 5 + + logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id) + + # check that all scores are -inf except the eos_token_id when max_length is reached + input_ids = ids_tensor((batch_size, 4), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores = logits_processor(input_ids, scores) + self.assertTrue(torch.isneginf(scores[:, eos_token_id + 1 :]).all()) + self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero + + # check that eos_token_id is not forced if max_length is not reached + input_ids = ids_tensor((batch_size, 3), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores = logits_processor(input_ids, scores) + self.assertFalse(torch.isinf(scores).any()) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index e54a3579cb..d1f01a7ae1 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -26,6 +26,8 @@ if is_torch_available(): from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering from transformers.generation_beam_search import BeamSearchScorer from
transformers.generation_logits_process import ( + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, HammingDiversityLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, @@ -70,7 +72,14 @@ class GenerationTesterMixin: return config, input_ids, attention_mask, max_length @staticmethod - def _get_logits_processor_and_kwargs(input_length, eos_token_id, diversity_penalty=None): + def _get_logits_processor_and_kwargs( + input_length, + eos_token_id, + forced_bos_token_id=None, + forced_eos_token_id=None, + max_length=None, + diversity_penalty=None, + ): process_kwargs = { "min_length": input_length + 1, "bad_words_ids": [[1, 0]], @@ -92,6 +101,18 @@ class GenerationTesterMixin: if eos_token_id is not None else [] ) + + ( + [ + ForcedBOSTokenLogitsProcessor(forced_bos_token_id), + ] + if forced_bos_token_id is not None + else [] + ) + + ( + [ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)] + if forced_eos_token_id is not None + else [] + ) + [ NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id), NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]), @@ -182,13 +203,17 @@ class GenerationTesterMixin: output_hidden_states=False, return_dict_in_generate=False, ): + if model.config.is_encoder_decoder: + max_length = 4 logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], model.config.eos_token_id + input_ids.shape[-1], + eos_token_id=model.config.eos_token_id, + forced_bos_token_id=model.config.forced_bos_token_id, + forced_eos_token_id=model.config.forced_eos_token_id, + max_length=max_length, ) kwargs = {} - if model.config.is_encoder_decoder: - max_length = 4 output_generate = model.generate( input_ids, @@ -544,14 +569,19 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - 
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], model.config.eos_token_id - ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) if model.config.is_encoder_decoder: max_length = 4 + process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + model.config.eos_token_id, + forced_bos_token_id=model.config.forced_bos_token_id, + forced_eos_token_id=model.config.forced_eos_token_id, + max_length=max_length, + ) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + # check `generate()` and `sample()` are equal output_sample, output_generate = self._sample_generate( model=model, @@ -586,14 +616,18 @@ class GenerationTesterMixin: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() config.use_cache = False model = model_class(config).to(torch_device).eval() - process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], model.config.eos_token_id - ) - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) - if model.config.is_encoder_decoder: max_length = 4 + process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + model.config.eos_token_id, + forced_bos_token_id=model.config.forced_bos_token_id, + forced_eos_token_id=model.config.forced_eos_token_id, + max_length=max_length, + ) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + output_sample, output_generate = self._sample_generate( model=model, input_ids=input_ids, @@ -630,14 +664,19 @@ class GenerationTesterMixin: # shorter than `max_length` can be generated which could lead to flaky circle ci # failures if the top `num_return_sequences` beams are all shorter than the longest beam config.eos_token_id = None + config.forced_eos_token_id = None model = model_class(config).to(torch_device).eval() - - 
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], config.eos_token_id - ) if model.config.is_encoder_decoder: max_length = 4 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + ) beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) # check `generate()` and `beam_search()` are equal @@ -684,13 +723,19 @@ class GenerationTesterMixin: # shorter than `max_length` can be generated which could lead to flaky circle ci # failures if the top `num_return_sequences` beams are all shorter than the longest beam config.eos_token_id = None + config.forced_eos_token_id = None model = model_class(config).to(torch_device).eval() - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], config.eos_token_id - ) if model.config.is_encoder_decoder: max_length = 4 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + ) beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) output_generate, output_beam_search = self._beam_search_generate( model=model, @@ -732,19 +777,24 @@ class GenerationTesterMixin: # shorter than `max_length` can be generated which could lead to flaky circle ci # failures if the top `num_return_sequences` beams are all shorter than the longest beam config.eos_token_id = None + config.forced_eos_token_id = None if not hasattr(config, "use_cache"): # only relevant if model has "use_cache" return model = model_class(config).to(torch_device).eval() - - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], config.eos_token_id - ) - if 
model.config.is_encoder_decoder: max_length = 4 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + ) + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) config.use_cache = True @@ -780,6 +830,7 @@ class GenerationTesterMixin: # shorter than `max_length` can be generated which could lead to flaky circle ci # failures if the top `num_return_sequences` beams are all shorter than the longest beam config.eos_token_id = None + config.forced_eos_token_id = None logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) @@ -819,6 +870,7 @@ class GenerationTesterMixin: # shorter than `max_length` can be generated which could lead to flaky circle ci # failures if the top `num_return_sequences` beams are all shorter than the longest beam config.eos_token_id = None + config.forced_eos_token_id = None model = model_class(config).to(torch_device).eval() logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) @@ -892,16 +944,22 @@ class GenerationTesterMixin: # shorter than `max_length` can be generated which could lead to flaky circle ci # failures if the top `num_return_sequences` beams are all shorter than the longest beam config.eos_token_id = None - - logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0 - ) + config.forced_eos_token_id = None model = model_class(config).to(torch_device).eval() - - # check `generate()` and `group_beam_search()` are equal if model.config.is_encoder_decoder: max_length = 4 + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + diversity_penalty=2.0, + ) + + # check 
`generate()` and `group_beam_search()` are equal beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(input_ids.shape[0], max_length) output_generate, output_group_beam_search = self._group_beam_search_generate( model=model, @@ -943,16 +1001,22 @@ class GenerationTesterMixin: # shorter than `max_length` can be generated which could lead to flaky circle ci # failures if the top `num_return_sequences` beams are all shorter than the longest beam config.eos_token_id = None + config.forced_eos_token_id = None model = model_class(config).to(torch_device).eval() + if model.config.is_encoder_decoder: + max_length = 4 logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( - input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0 + input_ids.shape[-1], + config.eos_token_id, + config.forced_bos_token_id, + config.forced_eos_token_id, + max_length, + diversity_penalty=2.0, ) num_return_sequences = 1 - if model.config.is_encoder_decoder: - max_length = 4 beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs( input_ids.shape[0], max_length, num_return_sequences=num_return_sequences ) diff --git a/tests/test_pipelines_summarization.py b/tests/test_pipelines_summarization.py index 2d2bc3330d..17f952a2c2 100644 --- a/tests/test_pipelines_summarization.py +++ b/tests/test_pipelines_summarization.py @@ -46,6 +46,7 @@ class SimpleSummarizationPipelineTests(unittest.TestCase): decoder_attention_heads=1, max_length=4, min_length=1, + forced_eos_token_id=None, ) model = BartForConditionalGeneration(config) # Bias output towards L