diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 3738a4cae7..0221622c40 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -158,9 +158,6 @@ generation. [[autodoc]] LogitsProcessorList - __call__ -[[autodoc]] LogitsWarper - - __call__ - [[autodoc]] MinLengthLogitsProcessor - __call__ @@ -421,4 +418,3 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] WatermarkDetector - __call__ - diff --git a/docs/source/ja/internal/generation_utils.md b/docs/source/ja/internal/generation_utils.md index d65067fc0b..9e3ce77995 100644 --- a/docs/source/ja/internal/generation_utils.md +++ b/docs/source/ja/internal/generation_utils.md @@ -157,9 +157,6 @@ generation_output[:2] [[autodoc]] LogitsProcessorList - __call__ -[[autodoc]] LogitsWarper - - __call__ - [[autodoc]] MinLengthLogitsProcessor - __call__ diff --git a/docs/source/zh/internal/generation_utils.md b/docs/source/zh/internal/generation_utils.md index c82deecd3d..75f28c233e 100644 --- a/docs/source/zh/internal/generation_utils.md +++ b/docs/source/zh/internal/generation_utils.md @@ -151,9 +151,6 @@ generation_output[:2] [[autodoc]] LogitsProcessorList - __call__ -[[autodoc]] LogitsWarper - - __call__ - [[autodoc]] MinLengthLogitsProcessor - __call__ diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index c7e626f1a7..aa5e77ac68 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -190,9 +190,9 @@ class GenerationConfig(PushToHubMixin): triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word. renormalize_logits (`bool`, *optional*, defaults to `False`): - Whether to renormalize the logits after applying all the logits processors or warpers (including the custom + Whether to renormalize the logits after applying all the logits processors (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits - are normalized but some logit processors or warpers break the normalization. + are normalized but some logit processors break the normalization. constraints (`List[Constraint]`, *optional*): Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by `Constraint` objects, in the most sensible way possible. diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index b226a059d1..7f89e23924 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -55,6 +55,12 @@ class LogitsProcessor: class LogitsWarper: """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" + def __init__(self): + logger.warning_once( + "`LogitsWarper` is deprecated and will be removed in v4.48. Your class should inherit `LogitsProcessor` " + "instead, which has the same properties and interface." + ) + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: raise NotImplementedError( @@ -64,9 +70,9 @@ class LogitsWarper: class LogitsProcessorList(list): """ - This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a - `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each - [`LogitsProcessor`] or [`LogitsWarper`] to the inputs. + This class can be used to create a list of [`LogitsProcessor`] to subsequently process a `scores` input tensor. + This class inherits from list and adds a specific *__call__* method to apply each [`LogitsProcessor`] to the + inputs. """ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor: @@ -233,9 +239,9 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): return scores_processed -class TemperatureLogitsWarper(LogitsWarper): +class TemperatureLogitsWarper(LogitsProcessor): r""" - [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means + [`LogitsProcessor`] for temperature (exponential scaling output probability distribution), which effectively means that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and [`TopKLogitsWarper`]. @@ -408,10 +414,10 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor): return scores_processed -class TopPLogitsWarper(LogitsWarper): +class TopPLogitsWarper(LogitsProcessor): """ - [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. Often - used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`]. + [`LogitsProcessor`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. + Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`]. Args: top_p (`float`): @@ -475,10 +481,10 @@ class TopPLogitsWarper(LogitsWarper): return scores_processed -class TopKLogitsWarper(LogitsWarper): +class TopKLogitsWarper(LogitsProcessor): r""" - [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together - with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`]. + [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements. Often used + together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`]. Args: top_k (`int`): @@ -528,9 +534,9 @@ class TopKLogitsWarper(LogitsWarper): return scores_processed -class MinPLogitsWarper(LogitsWarper): +class MinPLogitsWarper(LogitsProcessor): """ - [`LogitsWarper`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the + [`LogitsProcessor`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the probability of the most likely token. As a result, the filter becomes more agressive in the presence of high-probability tokens, which is a sign of a confident output that we shouldn't deviate from. @@ -605,11 +611,11 @@ class MinPLogitsWarper(LogitsWarper): return scores_processed -class TypicalLogitsWarper(LogitsWarper): +class TypicalLogitsWarper(LogitsProcessor): r""" - [`LogitsWarper`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens whose - log probability is close to the entropy of the token probability distribution. This means that the most likely - tokens may be discarded in the process. + [`LogitsProcessor`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens + whose log probability is close to the entropy of the token probability distribution. This means that the most + likely tokens may be discarded in the process. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information. @@ -693,9 +699,9 @@ class TypicalLogitsWarper(LogitsWarper): return scores_processed -class EpsilonLogitsWarper(LogitsWarper): +class EpsilonLogitsWarper(LogitsProcessor): r""" - [`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the + [`LogitsProcessor`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. @@ -762,15 +768,15 @@ class EpsilonLogitsWarper(LogitsWarper): return scores_processed -class EtaLogitsWarper(LogitsWarper): +class EtaLogitsWarper(LogitsProcessor): r""" - [`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic + [`LogitsProcessor`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample` - must be set to `True` for this `LogitsWarper` to work. + must be set to `True` for this `LogitsProcessor` to work. Args: @@ -1708,9 +1714,9 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): return scores_processed -class LogitNormalization(LogitsProcessor, LogitsWarper): +class LogitNormalization(LogitsProcessor): r""" - [`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize + [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that the scores are normalized when comparing the hypotheses. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 998288bd38..24c9e3bb18 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -735,61 +735,6 @@ class GenerationMixin: ) return candidate_generator - def _get_logits_warper( - self, - generation_config: GenerationConfig, - device: str, - ) -> LogitsProcessorList: - """ - This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances - used for multinomial sampling. - """ - - # instantiate warpers list - warpers = LogitsProcessorList() - - # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a - # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1) - if generation_config.num_beams > 1: - if isinstance(generation_config._eos_token_tensor, list): - min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 - elif isinstance(generation_config._eos_token_tensor, torch.Tensor): - min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 - else: - min_tokens_to_keep = 2 - else: - min_tokens_to_keep = 1 - - # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files - # all samplers can be found in `generation_utils_samplers.py` - if generation_config.temperature is not None and generation_config.temperature != 1.0: - warpers.append(TemperatureLogitsWarper(generation_config.temperature)) - if generation_config.top_k is not None and generation_config.top_k != 0: - warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) - if generation_config.top_p is not None and generation_config.top_p < 1.0: - warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) - if generation_config.min_p is not None: - # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) - warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)) - if generation_config.typical_p is not None and generation_config.typical_p < 1.0: - warpers.append( - TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) - ) - if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: - warpers.append( - EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep) - ) - if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: - warpers.append( - EtaLogitsWarper( - epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device - ) - ) - # `LogitNormalization` should always be the last logit processor, when present - if generation_config.renormalize_logits is True: - warpers.append(LogitNormalization()) - return warpers - def _get_logits_processor( self, generation_config: GenerationConfig, @@ -960,7 +905,58 @@ class GenerationMixin: context_width=generation_config.watermarking_config.context_width, ) ) + + # TODO (joao): find a strategy to specify the order of the processors processors = self._merge_criteria_processor_list(processors, logits_processor) + + # Processors previously known as `LogitsWarpers`, only applied with sampling strategies + if generation_config.do_sample: + # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a + # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1) + if generation_config.num_beams > 1: + if isinstance(generation_config._eos_token_tensor, list): + min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 + elif isinstance(generation_config._eos_token_tensor, torch.Tensor): + min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 + else: + min_tokens_to_keep = 2 + else: + min_tokens_to_keep = 1 + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if generation_config.temperature is not None and generation_config.temperature != 1.0: + processors.append(TemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + processors.append( + TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + processors.append( + TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.min_p is not None: + # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084) + processors.append( + MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.typical_p is not None and generation_config.typical_p < 1.0: + processors.append( + TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: + processors.append( + EpsilonLogitsWarper( + epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep + ) + ) + if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: + processors.append( + EtaLogitsWarper( + epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device + ) + ) + # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: processors.append(LogitNormalization()) @@ -1940,22 +1936,11 @@ class GenerationMixin: model_kwargs=model_kwargs, ) - # 12. prepare logits warper (if `do_sample` is `True`) - prepared_logits_warper = ( - self._get_logits_warper( - generation_config, - device=input_ids.device, - ) - if generation_config.do_sample - else None - ) - - # 13. run assisted generate + # 12. run assisted generate result = self._assisted_decoding( input_ids, candidate_generator=candidate_generator, logits_processor=prepared_logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, @@ -1968,16 +1953,10 @@ class GenerationMixin: raise ValueError( f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}" ) - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) result = self._dola_decoding( input_ids, dola_layers=generation_config.dola_layers, logits_processor=prepared_logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, @@ -2005,14 +1984,7 @@ class GenerationMixin: ) elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 11. prepare logits warper - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) - - # 12. expand input_ids with `num_return_sequences` additional sequences per batch + # 11. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_return_sequences, @@ -2020,11 +1992,10 @@ class GenerationMixin: **model_kwargs, ) - # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) + # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) result = self._sample( input_ids, logits_processor=prepared_logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, @@ -2033,14 +2004,7 @@ class GenerationMixin: ) elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): - # 11. prepare logits warper - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) - - # 12. prepare beam search scorer + # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=generation_config.num_beams, @@ -2051,7 +2015,7 @@ class GenerationMixin: max_length=generation_config.max_length, ) - # 13. interleave input_ids with `num_beams` additional sequences per batch + # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_beams, @@ -2059,12 +2023,11 @@ class GenerationMixin: **model_kwargs, ) - # 14. run beam sample + # 13. run beam sample result = self._beam_search( input_ids, beam_scorer, logits_processor=prepared_logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, @@ -2287,7 +2250,6 @@ class GenerationMixin: generation_config: GenerationConfig, synced_gpus: bool, streamer: "BaseStreamer", - logits_warper: Optional[LogitsProcessorList], **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" @@ -2316,10 +2278,6 @@ class GenerationMixin: streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -2344,11 +2302,6 @@ class GenerationMixin: return_dict_in_generate = generation_config.return_dict_in_generate has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) do_sample = generation_config.do_sample - if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): - raise ValueError( - "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " - f"{logits_warper})." - ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None @@ -2436,8 +2389,7 @@ class GenerationMixin: ) # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) - if do_sample: # sample - next_token_scores = logits_warper(input_ids, next_token_scores) + # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: @@ -2893,7 +2845,6 @@ class GenerationMixin: generation_config: GenerationConfig, synced_gpus: bool, streamer: Optional["BaseStreamer"], - logits_warper: Optional[LogitsProcessorList], **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: r""" @@ -2916,11 +2867,6 @@ class GenerationMixin: streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in - `generation_config`) model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -2942,11 +2888,6 @@ class GenerationMixin: max_length = generation_config.max_length has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) do_sample = generation_config.do_sample - if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): - raise ValueError( - "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " - f"{logits_warper})." - ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None @@ -2990,8 +2931,6 @@ class GenerationMixin: # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) - if do_sample: - next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -3105,7 +3044,6 @@ class GenerationMixin: stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, synced_gpus: bool, - logits_warper: Optional[LogitsProcessorList], **model_kwargs, ) -> Union[GenerateBeamOutput, torch.LongTensor]: r""" @@ -3128,11 +3066,6 @@ class GenerationMixin: The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - logits_warper (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in - `generation_config`) model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -3154,11 +3087,6 @@ class GenerationMixin: return_dict_in_generate = generation_config.return_dict_in_generate sequential = generation_config.low_memory do_sample = generation_config.do_sample - if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): - raise ValueError( - "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " - f"{logits_warper})." - ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams @@ -3249,8 +3177,6 @@ class GenerationMixin: ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) - if do_sample: - next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( next_token_scores_processed ) @@ -3698,10 +3624,6 @@ class GenerationMixin: stopping_criteria (`StoppingCriteriaList`): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. - logits_warper (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. generation_config ([`~generation.GenerationConfig`]): The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): @@ -3915,7 +3837,6 @@ class GenerationMixin: input_ids: torch.LongTensor, candidate_generator: CandidateGenerator, logits_processor: LogitsProcessorList, - logits_warper: LogitsProcessorList, stopping_criteria: StoppingCriteriaList, generation_config: GenerationConfig, synced_gpus: bool, @@ -3937,10 +3858,6 @@ class GenerationMixin: logits_processor (`LogitsProcessorList`): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. - logits_warper (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used - to warp the prediction score distribution of the language modeling head applied before multinomial - sampling at each generation step. Only used if sampling is active. stopping_criteria (`StoppingCriteriaList`): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. @@ -3963,7 +3880,7 @@ class GenerationMixin: `model.config.is_encoder_decoder=True`. """ # init values - do_sample = logits_warper is not None + do_sample = generation_config.do_sample output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -4047,9 +3964,6 @@ class GenerationMixin: if len(logits_processor) > 0: for i in range(candidate_length + 1): new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) - if do_sample and len(logits_warper) > 0: - for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) # 3. Select the accepted tokens. There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py index b03fd6796a..036c9caa83 100644 --- a/src/transformers/models/bark/generation_configuration_bark.py +++ b/src/transformers/models/bark/generation_configuration_bark.py @@ -56,9 +56,9 @@ class BarkSemanticGenerationConfig(GenerationConfig): eos_token_id (`int`, *optional*, defaults to 10_000): The id of the *end-of-sequence* token. renormalize_logits (`bool`, *optional*, defaults to `True`): - Whether to renormalize the logits after applying all the logits processors or warpers (including the + Whether to renormalize the logits after applying all the logits processors (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the - score logits are normalized but some logit processors or warpers break the normalization. + score logits are normalized but some logit processors break the normalization. max_new_tokens (`int`, *optional*, defaults to 768): The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt. output_scores (`bool`, *optional*, defaults to `False`): @@ -143,9 +143,9 @@ class BarkCoarseGenerationConfig(GenerationConfig): Args: renormalize_logits (`bool`, *optional*, defaults to `True`): - Whether to renormalize the logits after applying all the logits processors or warpers (including the + Whether to renormalize the logits after applying all the logits processors (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the - score logits are normalized but some logit processors or warpers break the normalization. + score logits are normalized but some logit processors break the normalization. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index b0e456db8a..f720faac03 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1609,13 +1609,6 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): ) if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 11. prepare logits warper - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) - # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, @@ -1623,11 +1616,10 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel): **model_kwargs, ) - # 12. run sample + # 11. run sample outputs = self._sample( input_ids, logits_processor=logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, @@ -2649,13 +2641,6 @@ class MusicgenForConditionalGeneration(PreTrainedModel): ) if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 11. prepare logits warper - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) - # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, @@ -2664,11 +2649,10 @@ class MusicgenForConditionalGeneration(PreTrainedModel): **model_kwargs, ) - # 12. run sample + # 11. run sample outputs = self._sample( input_ids, logits_processor=logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index ba19e546a1..a8a8fe9609 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -1531,13 +1531,6 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): ) if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 11. prepare logits warper - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) - # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, @@ -1545,11 +1538,10 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): **model_kwargs, ) - # 12. run sample + # 11. run sample outputs = self._sample( input_ids, logits_processor=logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, @@ -2490,13 +2482,6 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): ) if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): - # 11. prepare logits warper - prepared_logits_warper = ( - self._get_logits_warper(generation_config, device=input_ids.device) - if generation_config.do_sample - else None - ) - # expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, @@ -2505,11 +2490,10 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): **model_kwargs, ) - # 12. run sample + # 11. run sample outputs = self._sample( input_ids, logits_processor=logits_processor, - logits_warper=prepared_logits_warper, stopping_criteria=stopping_criteria, generation_config=generation_config, synced_gpus=synced_gpus, diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index d2f92bfd71..bc375b68e9 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1558,7 +1558,6 @@ class RagTokenForGeneration(RagPreTrainedModel): generation_config=generation_config, synced_gpus=False, streamer=None, - logits_warper=None, **model_kwargs, ) elif generation_config.num_beams > 1: @@ -1580,7 +1579,6 @@ class RagTokenForGeneration(RagPreTrainedModel): stopping_criteria=prepared_stopping_criteria, generation_config=generation_config, synced_gpus=False, - logits_warper=None, **model_kwargs, ) else: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 17f788b26e..72da44115f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -118,26 +118,24 @@ class GenerationTesterMixin: return config, input_ids, attention_mask - @staticmethod - def _get_logits_processor_and_warper_kwargs( - input_length, - forced_bos_token_id=None, - forced_eos_token_id=None, - ): - process_kwargs = { + def _get_logits_processor_kwargs(self, do_sample=False): + logits_processor_kwargs = { "bad_words_ids": [[1, 0]], "repetition_penalty": 1.2, "remove_invalid_values": True, } - # NoRepeatNGramLogitsProcessor + forced tokens may result in no valid continuations - if forced_bos_token_id is None and forced_eos_token_id is None: - process_kwargs["no_repeat_ngram_size"] = 2 + if do_sample: + logits_processor_kwargs.update( + { + "top_k": 10, + "top_p": 0.7, + "temperature": 0.7, + } + ) - warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7} - return process_kwargs, warp_kwargs + return logits_processor_kwargs - @staticmethod - def _get_beam_kwargs(num_return_sequences=1): + def _get_beam_kwargs(self, num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, @@ -146,8 +144,7 @@ class GenerationTesterMixin: } return beam_kwargs - @staticmethod - def _get_diverse_beam_kwargs(num_return_sequences=1): + def _get_diverse_beam_kwargs(self, num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, @@ -158,8 +155,7 @@ class GenerationTesterMixin: } return beam_kwargs - @staticmethod - def _get_constrained_beam_kwargs(num_return_sequences=1): + def _get_constrained_beam_kwargs(self, num_return_sequences=1): beam_kwargs = { "early_stopping": False, "length_penalty": 2.0, @@ -199,12 +195,7 @@ class GenerationTesterMixin: output_hidden_states=False, return_dict_in_generate=False, ): - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - ) - + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -216,7 +207,7 @@ class GenerationTesterMixin: output_scores=output_scores, output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, ) @@ -228,8 +219,6 @@ class GenerationTesterMixin: input_ids, attention_mask, num_return_sequences, - logits_warper_kwargs, - process_kwargs, output_scores=False, output_logits=False, output_attentions=False, @@ -237,6 +226,7 @@ class GenerationTesterMixin: return_dict_in_generate=False, ): torch.manual_seed(0) + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -249,8 +239,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, - **logits_warper_kwargs, - **process_kwargs, + **logits_processor_kwargs, **model_kwargs, ) @@ -262,13 +251,13 @@ class GenerationTesterMixin: input_ids, attention_mask, beam_kwargs, - logits_process_kwargs, output_scores=False, output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, ): + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -280,7 +269,7 @@ class GenerationTesterMixin: output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, **beam_kwargs, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, ) @@ -292,7 +281,6 @@ class GenerationTesterMixin: input_ids, attention_mask, beam_kwargs, - logits_warper_kwargs, output_scores=False, output_logits=False, output_attentions=False, @@ -300,6 +288,7 @@ class GenerationTesterMixin: return_dict_in_generate=False, ): torch.manual_seed(0) + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -311,7 +300,7 @@ class GenerationTesterMixin: output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, **beam_kwargs, - **logits_warper_kwargs, + **logits_processor_kwargs, **model_kwargs, ) @@ -323,13 +312,13 @@ class GenerationTesterMixin: input_ids, attention_mask, beam_kwargs, - logits_process_kwargs, output_scores=False, output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, ): + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -341,7 +330,7 @@ class GenerationTesterMixin: output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, **beam_kwargs, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, ) @@ -354,13 +343,13 @@ class GenerationTesterMixin: attention_mask, constraints, beam_kwargs, - logits_process_kwargs, output_scores=False, output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, ): + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -373,7 +362,7 @@ class GenerationTesterMixin: return_dict_in_generate=return_dict_in_generate, constraints=constraints, **beam_kwargs, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, ) @@ -395,12 +384,7 @@ class GenerationTesterMixin: "top_k": 5, } - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - ) - + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_generate = model.generate( input_ids, @@ -412,7 +396,7 @@ class GenerationTesterMixin: output_scores=output_scores, output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, - **logits_process_kwargs, + **logits_processor_kwargs, **model_kwargs, **contrastive_search_kwargs, ) @@ -495,19 +479,11 @@ class GenerationTesterMixin: config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - ) - output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, num_return_sequences=1, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, ) if model.config.is_encoder_decoder: @@ -521,20 +497,11 @@ class GenerationTesterMixin: config.use_cache = False model = model_class(config).to(torch_device).eval() - - process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - forced_bos_token_id=model.config.forced_bos_token_id, - forced_eos_token_id=model.config.forced_eos_token_id, - ) - output_generate = self._sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, num_return_sequences=2, - logits_warper_kwargs=logits_warper_kwargs, - process_kwargs=process_kwargs, output_scores=True, output_logits=True, output_hidden_states=True, @@ -561,19 +528,12 @@ class GenerationTesterMixin: model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) beam_kwargs = self._get_beam_kwargs() - output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, ) if model.config.is_encoder_decoder: @@ -589,18 +549,12 @@ class GenerationTesterMixin: config.use_cache = False model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, output_scores=True, output_logits=True, output_hidden_states=True, @@ -633,12 +587,6 @@ class GenerationTesterMixin: self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) - beam_kwargs = self._get_beam_kwargs() config.use_cache = True @@ -649,7 +597,6 @@ class GenerationTesterMixin: input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, output_scores=True, output_logits=True, output_hidden_states=True, @@ -693,17 +640,13 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1]) - model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() - output_generate = self._beam_sample_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_warper_kwargs=logits_warper_kwargs, ) if model.config.is_encoder_decoder: @@ -711,7 +654,13 @@ class GenerationTesterMixin: else: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters): + prepare_inputs_for_generation_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters) + # `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling + # code is up to date with our most recent standards + if ( + "inputs_embeds" in prepare_inputs_for_generation_args + and "cache_positions" in prepare_inputs_for_generation_args + ): input_embeds = model.get_input_embeddings()(input_ids) beam_kwargs.update({"inputs_embeds": input_embeds}) output_generate2 = self._beam_sample_generate( @@ -719,7 +668,6 @@ class GenerationTesterMixin: input_ids=None, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_warper_kwargs=logits_warper_kwargs, ) torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2) @@ -732,7 +680,6 @@ class GenerationTesterMixin: config.use_cache = False model = model_class(config).to(torch_device).eval() - _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1]) beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_sample_generate( @@ -740,7 +687,6 @@ class GenerationTesterMixin: input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_warper_kwargs=logits_warper_kwargs, output_scores=True, output_logits=True, output_hidden_states=True, @@ -788,12 +734,6 @@ class GenerationTesterMixin: config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) - # check `generate()` and `group_beam_search()` are equal beam_kwargs = self._get_diverse_beam_kwargs() output_generate = self._group_beam_search_generate( @@ -801,7 +741,6 @@ class GenerationTesterMixin: input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) @@ -816,7 +755,6 @@ class GenerationTesterMixin: input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) @@ -829,19 +767,12 @@ class GenerationTesterMixin: config.use_cache = False model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) - beam_kwargs = self._get_diverse_beam_kwargs() output_generate = self._group_beam_search_generate( model=model, input_ids=input_ids, attention_mask=attention_mask, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, output_scores=True, output_logits=True, output_hidden_states=True, @@ -871,12 +802,6 @@ class GenerationTesterMixin: model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) - # Sample constraints min_id = 3 max_id = config.vocab_size @@ -893,7 +818,6 @@ class GenerationTesterMixin: attention_mask=attention_mask, constraints=constraints, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, ) if model.config.is_encoder_decoder: @@ -919,7 +843,6 @@ class GenerationTesterMixin: attention_mask=attention_mask, constraints=constraints, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, ) if model.config.is_encoder_decoder: @@ -938,11 +861,6 @@ class GenerationTesterMixin: config.use_cache = False model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) # Sample constraints min_id = 3 @@ -959,7 +877,6 @@ class GenerationTesterMixin: attention_mask=attention_mask, constraints=constraints, beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, output_scores=True, output_logits=True, output_hidden_states=True, diff --git a/tests/models/biogpt/test_modeling_biogpt.py b/tests/models/biogpt/test_modeling_biogpt.py index 1ccb2b54cc..4f1d5d6a42 100644 --- a/tests/models/biogpt/test_modeling_biogpt.py +++ b/tests/models/biogpt/test_modeling_biogpt.py @@ -414,10 +414,6 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - @unittest.skip(reason="The `input_embeds` when fed don't produce the same results.") - def test_beam_sample_generate(self): - pass - @require_torch class BioGptModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index cd800da976..e7e3a7242c 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -433,6 +433,10 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + @unittest.skip("The `input_embeds` when fed don't produce the same results.") + def test_beam_sample_generate(self): + pass + @require_torch class MambaIntegrationTests(unittest.TestCase): diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 13cc22561f..276ecf2fd6 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -283,6 +283,12 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + @unittest.skip( + reason="Mamba2 does not support generating with input embeddings (custom cache_position computation)" + ) + def test_inputs_embeds_matches_input_ids_with_generate(self): + pass + @require_torch @slow diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 7fc2f8c9db..870a4c9276 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -293,15 +293,9 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) return config, input_ids, attention_mask - @staticmethod - def _get_logits_processor_and_warper_kwargs( - input_length, - forced_bos_token_id=None, - forced_eos_token_id=None, - ): - process_kwargs = {} - warper_kwargs = {} - return process_kwargs, warper_kwargs + def _get_logits_processor_kwargs(self, do_sample=False): + logits_processor_kwargs = {} + return logits_processor_kwargs def test_greedy_generate_stereo_outputs(self): for model_class in self.greedy_sample_model_classes: @@ -1483,15 +1477,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, return output_generate - @staticmethod - def _get_logits_processor_and_warper_kwargs( - input_length, - forced_bos_token_id=None, - forced_eos_token_id=None, - ): - process_kwargs = {} - warper_kwargs = {} - return process_kwargs, warper_kwargs + def _get_logits_processor_kwargs(self, do_sample=False): + logits_processor_kwargs = {} + return logits_processor_kwargs def test_greedy_generate_dict_outputs(self): for model_class in self.greedy_sample_model_classes: diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 7cebf037d2..9b34f4dde6 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -296,15 +296,9 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long) return config, input_ids, attention_mask - @staticmethod - def _get_logits_processor_and_warper_kwargs( - input_length, - forced_bos_token_id=None, - forced_eos_token_id=None, - ): - process_kwargs = {} - warper_kwargs = {} - return process_kwargs, warper_kwargs + def _get_logits_processor_kwargs(self, do_sample=False): + logits_processor_kwargs = {} + return logits_processor_kwargs def test_greedy_generate_stereo_outputs(self): for model_class in self.greedy_sample_model_classes: @@ -1467,15 +1461,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester return output_generate - @staticmethod - def _get_logits_processor_and_warper_kwargs( - input_length, - forced_bos_token_id=None, - forced_eos_token_id=None, - ): - process_kwargs = {} - warper_kwargs = {} - return process_kwargs, warper_kwargs + def _get_logits_processor_kwargs(self, do_sample=False): + logits_processor_kwargs = {} + return logits_processor_kwargs def test_greedy_generate_dict_outputs(self): for model_class in self.greedy_sample_model_classes: diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py index ad542db273..1a58ee2970 100644 --- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py +++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py @@ -413,6 +413,10 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT def test_initialization(self): pass + @unittest.skip(reason="RecurrentGemma does not support generating with input embeddings (missing position_ids)") + def test_inputs_embeds_matches_input_ids_with_generate(self): + pass + @require_torch_gpu @slow diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 155588ad02..6deebf552b 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -68,14 +68,7 @@ if is_torch_available(): set_seed, ) from transformers.generation import ( - BeamSampleDecoderOnlyOutput, - BeamSampleEncoderDecoderOutput, - BeamSearchDecoderOnlyOutput, - BeamSearchEncoderDecoderOutput, - GenerateBeamDecoderOnlyOutput, - GenerateBeamEncoderDecoderOutput, GenerateEncoderDecoderOutput, - PhrasalConstraint, ) from transformers.generation.logits_process import LogitsProcessor from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids @@ -419,6 +412,30 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi return False + def _get_logits_processor_kwargs(self, do_sample=False): + # Overwritten from `GenerationTesterMixin`, Whisper needs `"temperature": 0.0` to be able to do beam search + logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample) + logits_processor_kwargs["temperature"] = 0.0 + return logits_processor_kwargs + + def _get_beam_kwargs(self, num_return_sequences=1): + # Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate` + beam_kwargs = super()._get_beam_kwargs(num_return_sequences=num_return_sequences) + beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"] + return beam_kwargs + + def _get_diverse_beam_kwargs(self, num_return_sequences=1): + # Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate` + beam_kwargs = super()._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences) + beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"] + return beam_kwargs + + def _get_constrained_beam_kwargs(self, num_return_sequences=1): + # Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate` + beam_kwargs = super()._get_constrained_beam_kwargs(num_return_sequences=num_return_sequences) + beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"] + return beam_kwargs + def setUp(self): self.model_tester = WhisperModelTester(self) self.config_tester = ConfigTester(self, config_class=WhisperConfig) @@ -1551,241 +1568,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi def test_longform_generate_multi_batch_cond_prev(self): self._check_longform_generate_multi_batch(condition_on_prev_tokens=True) - def test_beam_sample_generate_dict_output(self): - # We overwrite test_beam_sample_generate_dict_output in test_utils as - # we can only perform beam search if the temperature is set to 0 in Whisper. - config, input_ids, attention_mask = self._get_input_ids_and_config() - - # disable cache - config.use_cache = False - - model = WhisperForConditionalGeneration(config).to(torch_device).eval() - _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1]) - beam_kwargs = self._get_beam_kwargs() - - # With Whisper, we can only perform a beam search if the temperature is set to 0. - logits_warper_kwargs["temperature"] = 0 - # We will return num_beams sequences per input only if num_return_sequences == num_beams: - beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"] - - output_generate = self._beam_sample_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, - beam_kwargs=beam_kwargs, - logits_warper_kwargs=logits_warper_kwargs, - output_scores=True, - output_logits=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) - self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) - # Retrocompatibility check - self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput) - else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) - # Retrocompatibility check - self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) - - self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]) - - def test_beam_search_generate_dict_output(self): - # We overwrite test_beam_search_generate_dict_output in test_utils as - # we can only perform beam search if the temperature is set to 0 in Whisper. - for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask = self._get_input_ids_and_config() - - # disable cache - config.use_cache = False - - model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) - beam_kwargs = self._get_beam_kwargs() - - # With Whisper, we can only perform a beam search if the temperature is set to 0. - logits_process_kwargs["temperature"] = 0 - # We will return num_beams sequences per input only if num_return_sequences == num_beams: - beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"] - - output_generate = self._beam_search_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, - beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, - output_scores=True, - output_logits=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) - self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) - # Retrocompatibility check - self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) - else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) - # Retrocompatibility check - self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - - self._check_outputs( - output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] - ) - - def test_beam_search_generate_dict_outputs_use_cache(self): - # We overwrite test_beam_search_generate_dict_outputs_use_cache in test_utils as - # we can only perform beam search if the temperature is set to 0 in Whisper. - for model_class in self.all_generative_model_classes: - # enable cache - config, input_ids, attention_mask = self._get_input_ids_and_config() - - if not hasattr(config, "use_cache"): - self.skipTest("This model doesn't support caching") - - model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) - - beam_kwargs = self._get_beam_kwargs() - - # We will return num_beams sequences per input only if num_return_sequences == num_beams: - beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"] - - config.use_cache = True - config.is_decoder = True - model = model_class(config).to(torch_device).eval() - output_generate = self._beam_search_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, - beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, - output_scores=True, - output_logits=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - - if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) - else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - self._check_outputs( - output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"] - ) - - def test_group_beam_search_generate_dict_output(self): - # We overwrite test_group_beam_search_generate_dict_output in test_utils as - # we can only perform beam search if the temperature is set to 0 in Whisper. - for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = False - - model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) - - beam_kwargs = self._get_diverse_beam_kwargs() - - # We will return num_beams sequences per input only if num_return_sequences == num_beams: - beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"] - - output_generate = self._group_beam_search_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, - beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, - output_scores=True, - output_logits=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) - self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) - # Retrocompatibility check - self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) - else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) - # Retrocompatibility check - self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - - self._check_outputs( - output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"] - ) - - def test_constrained_beam_search_generate_dict_output(self): - for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask = self._get_input_ids_and_config() - - # disable cache - config.use_cache = False - - model = model_class(config).to(torch_device).eval() - logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs( - input_ids.shape[-1], - config.forced_bos_token_id, - config.forced_eos_token_id, - ) - - # Sample constraints - min_id = 3 - max_id = model.config.vocab_size - force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0] - constraints = [ - PhrasalConstraint(force_tokens), - ] - - beam_kwargs = self._get_constrained_beam_kwargs() - output_generate = self._constrained_beam_search_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, - constraints=constraints, - beam_kwargs=beam_kwargs, - logits_process_kwargs=logits_process_kwargs, - output_scores=True, - output_logits=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - - if model.config.is_encoder_decoder: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) - self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput) - # Retrocompatibility check - self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) - else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1]) - self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) - # Retrocompatibility check - self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - - self._check_outputs( - output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"] - ) - @is_flaky() # TODO (joao, sanchit): fails ~9% of the times. Does the original test have the same issue? def test_custom_4d_attention_mask(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index f57427c4f6..928bd332d2 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -70,6 +70,7 @@ OBJECTS_TO_IGNORE = [ # Deprecated "InputExample", "InputFeatures", + "LogitsWarper", # Signature is *args/**kwargs "TFSequenceSummary", "TFBertTokenizer", diff --git a/utils/check_repo.py b/utils/check_repo.py index 02570e3c60..acd6662cc2 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -932,6 +932,7 @@ DEPRECATED_OBJECTS = [ "LineByLineTextDataset", "LineByLineWithRefDataset", "LineByLineWithSOPTextDataset", + "LogitsWarper", "NerPipeline", "PretrainedBartModel", "PretrainedFSMTModel",