From a26de151390f5cb029b2e39231c00ad4303b4347 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Aug 2024 20:01:52 +0100 Subject: [PATCH] Generate: Deprecate returning legacy cache by default; Handle `use_cache=False` (#32863) --- .../generation/configuration_utils.py | 109 +++--- src/transformers/generation/utils.py | 349 ++++++++++-------- src/transformers/models/rwkv/modeling_rwkv.py | 3 +- tests/generation/test_utils.py | 106 +++--- 4 files changed, 311 insertions(+), 256 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index aa5e77ac68..160a8a7eae 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -130,9 +130,29 @@ class GenerationConfig(PushToHubMixin): [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. penalty_alpha (`float`, *optional*): The values balance the model confidence and the degeneration penalty in contrastive search decoding. + dola_layers (`str` or `List[int]`, *optional*): + The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must + be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively. + "low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the + layers up to the last 20 layers. + If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa. + The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks, + `'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md) + or [the paper](https://arxiv.org/abs/2309.03883) for more details. + + > Parameters that control the cache + use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. + cache_implementation (`str`, *optional*, default to `None`): + Cache class that should be used when generating. + cache_config (`CacheConfig` or `dict`, *optional*, default to `None`): + Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and + it will be converted to its repsective `CacheConfig` internally. + Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`. + return_legacy_cache (`bool`, *optional*, default to `True`): + Whether to return the legacy or new format of the cache when `DynamicCache` is used by default. > Parameters for manipulation of the model output logits @@ -307,29 +327,6 @@ class GenerationConfig(PushToHubMixin): max_matching_ngram_size (`int`, *optional*, default to `None`): The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided. - > Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883) - - dola_layers (`str` or `List[int]`, *optional*): - The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must - be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively. - "low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the - layers up to the last 20 layers. - If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa. - The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks, - `'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md) - or [the paper](https://arxiv.org/abs/2309.03883) for more details. - - > Parameters specific to the caching mechanism: - - cache_implementation (`str`, *optional*, default to `None`): - Cache class that should be used when generating. - cache_config (`CacheConfig` or `dict`, *optional*, default to `None`): - Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and - it will be converted to its repsective `CacheConfig` internally. - Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`. - return_legacy_cache (`bool`, *optional*, default to `True`): - Whether to return the legacy or new format of the cache when `DynamicCache` is used by default. - > Wild card generation_kwargs: @@ -352,7 +349,19 @@ class GenerationConfig(PushToHubMixin): self.num_beams = kwargs.pop("num_beams", 1) self.num_beam_groups = kwargs.pop("num_beam_groups", 1) self.penalty_alpha = kwargs.pop("penalty_alpha", None) + self.dola_layers = kwargs.pop("dola_layers", None) + + # Parameters that control the cache self.use_cache = kwargs.pop("use_cache", True) + self.cache_implementation = kwargs.pop("cache_implementation", None) + self.cache_config = kwargs.pop("cache_config", None) + if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG: + cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation] + if self.cache_config is None: + self.cache_config = cache_config_class() + elif isinstance(self.cache_config, dict): + self.cache_config = cache_config_class.from_dict(self.cache_config) + self.return_legacy_cache = kwargs.pop("return_legacy_cache", None) # Parameters for manipulation of the model output logits self.temperature = kwargs.pop("temperature", 1.0) @@ -411,20 +420,6 @@ class GenerationConfig(PushToHubMixin): self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5) self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic") - # DoLa generation - self.dola_layers = kwargs.pop("dola_layers", None) - - # Cache implementation - self.cache_implementation = kwargs.pop("cache_implementation", None) - self.cache_config = kwargs.pop("cache_config", None) - if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG: - cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation] - if self.cache_config is None: - self.cache_config = cache_config_class() - elif isinstance(self.cache_config, dict): - self.cache_config = cache_config_class.from_dict(self.cache_config) - self.return_legacy_cache = kwargs.pop("return_legacy_cache", True) - # Prompt lookup decoding self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None) @@ -544,8 +539,9 @@ class GenerationConfig(PushToHubMixin): raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.") if self.pad_token_id is not None and self.pad_token_id < 0: warnings.warn( - f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch generating, if there is padding. " - "Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values." + f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch " + "generating, if there is padding. Please set `pad_token_id` explicitly as " + "`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation" ) # Validation of attribute relations: @@ -675,6 +671,14 @@ class GenerationConfig(PushToHubMixin): group_error_prefix + "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical." ) + # DoLa generation + if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2): + warnings.warn( + "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of " + f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for " + "DoLa decoding is `repetition_penalty>=1.2`.", + UserWarning, + ) # 4. check `num_return_sequences` if self.num_return_sequences != 1: @@ -690,7 +694,7 @@ class GenerationConfig(PushToHubMixin): f"({self.num_beams})." ) - # 5. check `cache_config` + # 5. check cache-related arguments if self.cache_config is not None: cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation) if cache_class is None: @@ -702,6 +706,20 @@ class GenerationConfig(PushToHubMixin): if not isinstance(self.cache_config, cache_class): self.cache_config = cache_class.from_dict(self.cache_config) self.cache_config.validate() + if self.use_cache is False: + # In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often used + # passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error + # (otherwise a user might need to overwrite several parameters). + no_cache_warning = ( + "You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will " + "have no effect." + ) + for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"): + if getattr(self, arg_name) is not None: + logger.warning_once( + no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)), + UserWarning, + ) # 6. check watermarking arguments if self.watermarking_config is not None: @@ -727,17 +745,6 @@ class GenerationConfig(PushToHubMixin): "`generate()` (or a pipeline) directly." ) - # 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2 - if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2): - dola_decoding_wrong_parameter_msg = ( - "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, " - "which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`." - ) - warnings.warn( - dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty), - UserWarning, - ) - def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6ab88ad26b..a9ebdcdd47 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -136,27 +136,23 @@ class GenerateDecoderOnlyOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. """ sequences: torch.LongTensor = None @@ -176,36 +172,32 @@ class GenerateEncoderDecoderOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. """ sequences: torch.LongTensor = None @@ -228,33 +220,29 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): Beam indices of generated token id at each generation step. `torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. """ sequences: torch.LongTensor = None @@ -276,43 +264,39 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): Beam indices of generated token id at each generation step. `torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. """ sequences: torch.LongTensor = None @@ -328,6 +312,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None +# TODO (joao): remove the equivalent classes and typing shortcuts below in v5 # Equivalent classes (kept for retrocompatibility purposes) GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput @@ -1501,6 +1486,121 @@ class GenerationMixin: """ return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower() + def _prepare_cache_for_generation( + self, + generation_config: GenerationConfig, + model_kwargs: Dict, + assistant_model: "PreTrainedModel", + batch_size: int, + device: torch.device, + ) -> bool: + """ + Prepares the cache for generation (if applicable), given `generate`'s paramaterization. If a cache is + instantiated, writes it to `model_kwargs`, under the name expected by the model. + """ + + cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + + # Quick escape route 1: if the user specifies a cache, we only need to: + # a) check for conflicting `generate` arguments + # b) convert to the new cache format (if the user passes a legacy cache and model supports it) + user_defined_cache = model_kwargs.get(cache_name) + if user_defined_cache is not None: + if generation_config.cache_implementation is not None: + raise ValueError( + f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache(): + model_kwargs[cache_name] = ( + DynamicCache.from_legacy_cache(user_defined_cache) + if not requires_cross_attention_cache + else EncoderDecoderCache.from_legacy_cache(user_defined_cache) + ) + return + + # Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in + # `generation_config.validate()`) + if generation_config.use_cache is False: + return + + # Quick escape route 3: model that only supports legacy caches = nothing to prepare + if not self._supports_default_dynamic_cache(): + if generation_config.cache_implementation is not None: + warnings.warn( + "This model does not support `Cache` instances, it only supports the legacy cache format (tuple " + f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be " + "ignored.", + UserWarning, + ) + return + + # Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation` + + # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches, + # which is only supported in dynamic caches atm + if assistant_model is not None and generation_config.cache_implementation is not None: + logger.warning_once( + "An assistant model is provided, using a dynamic cache instead of a cache of type=" + f"'{generation_config.cache_implementation}'." + ) + generation_config.cache_implementation = None + + if generation_config.cache_implementation is not None: + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static" and not self._supports_static_cache: + raise ValueError( + "This model does not support `cache_implementation='static'`. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" + ) + model_kwargs[cache_name] = self._get_cache( + cache_implementation=generation_config.cache_implementation, + batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size, + max_cache_len=generation_config.max_length, + device=device, + model_kwargs=model_kwargs, + ) + elif generation_config.cache_implementation == "quantized": + if not self._supports_quantized_cache: + raise ValueError( + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue and tag @zucchini-nlp." + ) + + cache_config = ( + generation_config.cache_config + if generation_config.cache_config is not None + else QuantizedCacheConfig() + ) + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + + if cache_config.backend == "quanto" and not is_quanto_available(): + raise ImportError( + "You need to install `quanto` in order to use KV cache quantization with quanto backend. " + "Please install it via with `pip install quanto`" + ) + elif cache_config.backend == "HQQ" and not is_hqq_available(): + raise ImportError( + "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " + "Please install it via with `pip install hqq`" + ) + + model_kwargs[cache_name] = cache_class(cache_config) + elif generation_config.cache_implementation == "offloaded": + model_kwargs[cache_name] = OffloadedCache() + + # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that + # keeps copying the cache thus using much more memory + else: + model_kwargs[cache_name] = ( + DynamicCache() + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(), DynamicCache()) + ) + def _prepare_special_tokens( self, generation_config: GenerationConfig, @@ -1776,104 +1876,18 @@ class GenerationMixin: inputs_tensor=inputs_tensor, input_ids_length=input_ids_length, ) - - use_dynamic_cache_by_default = False - if "mamba" in self.__class__.__name__.lower(): - cache_name = "cache_params" - else: - cache_name = "past_key_values" - - # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches, - # which is only supported in dynamic caches atm - if ( - assistant_model is not None - and generation_config.cache_implementation is not None - and self._supports_default_dynamic_cache() - ): - logger.warning_once( - "An assistant model is provided, using a dynamic cache instead of a cache of type=" - f"'{generation_config.cache_implementation}'." - ) - generation_config.cache_implementation = None - - if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling(): - raise ValueError( - "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you " - "may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` " - "input argument." - ) - if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None): - raise ValueError( - f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " - "Cache object) is unsupported. Please use only one of the two." - ) - elif generation_config.cache_implementation is not None: - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if generation_config.cache_implementation == "static" and not self._supports_static_cache: - raise ValueError( - "This model does not support `cache_implementation='static'`. Please check the following " - "issue: https://github.com/huggingface/transformers/issues/28981" - ) - model_kwargs[cache_name] = self._get_cache( - cache_implementation=generation_config.cache_implementation, - batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size, - max_cache_len=generation_config.max_length, - device=device, - model_kwargs=model_kwargs, - ) - elif generation_config.cache_implementation == "quantized": - if not self._supports_quantized_cache: - raise ValueError( - "This model does not support the quantized cache. If you want your model to support quantized " - "cache, please open an issue." - ) - - cache_config = ( - generation_config.cache_config - if generation_config.cache_config is not None - else QuantizedCacheConfig() - ) - cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] - - if cache_config.backend == "quanto" and not is_quanto_available(): - raise ImportError( - "You need to install `quanto` in order to use KV cache quantization with quanto backend. " - "Please install it via with `pip install quanto`" - ) - elif cache_config.backend == "HQQ" and not is_hqq_available(): - raise ImportError( - "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " - "Please install it via with `pip install hqq`" - ) - - model_kwargs[cache_name] = cache_class(cache_config) - elif generation_config.cache_implementation == "offloaded": - model_kwargs[cache_name] = OffloadedCache() - # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that - # keeps copying the cache thus using much more memory - elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): - past = model_kwargs.get(cache_name, None) - requires_cross_attention_cache = ( - self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None - ) - if past is None: - model_kwargs[cache_name] = ( - DynamicCache() - if not requires_cross_attention_cache - else EncoderDecoderCache(DynamicCache(), DynamicCache()) - ) - use_dynamic_cache_by_default = True - elif isinstance(past, tuple): - model_kwargs[cache_name] = ( - DynamicCache.from_legacy_cache(past) - if not requires_cross_attention_cache - else EncoderDecoderCache.from_legacy_cache(past) - ) - use_dynamic_cache_by_default = True - self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) - # 7. determine generation mode + # 7. Prepare the cache. + # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. + # - different models have a different cache name expected by the model (default = "past_key_values") + # - `max_length`, prepared above, is used to determine the maximum cache length + # TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format) + cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" + user_defined_cache = model_kwargs.get(cache_name) + self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, device) + + # 8. determine generation mode generation_mode = generation_config.get_generation_mode(assistant_model) if streamer is not None and (generation_config.num_beams > 1): @@ -1892,7 +1906,7 @@ class GenerationMixin: UserWarning, ) - # 8. prepare distribution pre_processing samplers + # 9. prepare logits processors and stopping criteria prepared_logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, @@ -1904,8 +1918,6 @@ class GenerationMixin: negative_prompt_ids=negative_prompt_ids, negative_prompt_attention_mask=negative_prompt_attention_mask, ) - - # 9. prepare stopping criteria prepared_stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs ) @@ -2138,11 +2150,34 @@ class GenerationMixin: **model_kwargs, ) - # Convert to legacy cache if needed - if use_dynamic_cache_by_default and generation_config.return_legacy_cache: - if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"): - if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)): - result.past_key_values = result.past_key_values.to_legacy_cache() + # Convert to legacy cache format if requested + if ( + generation_config.return_legacy_cache is not False # Should check for `True` after v4.47 + and not is_torchdynamo_compiling() + and hasattr(result, "past_key_values") + and hasattr(result.past_key_values, "to_legacy_cache") + and result.past_key_values.to_legacy_cache is not None + ): + # handle BC (convert by default if he user hasn't passed a cache AND the cache is of the default type) + should_convert_cache = generation_config.return_legacy_cache + is_user_defined_cache = user_defined_cache is not None + is_default_cache_type = ( + type(result.past_key_values) == DynamicCache # noqa E721 + or ( + isinstance(result.past_key_values, EncoderDecoderCache) + and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721 + and type(result.past_key_values.cross_attention_cache) == DynamicCache # noqa E721 + ) + ) + if not is_user_defined_cache and is_default_cache_type: + logger.warning_once( + "From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` " + "instance instead by default (as opposed to the legacy tuple of tuples format). If you want to " + "keep returning the legacy format, please set `return_legacy_cache=True`." + ) + should_convert_cache = True + if should_convert_cache: + result.past_key_values = result.past_key_values.to_legacy_cache() return result def _has_unfinished_sequences( diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index f6b8cd412b..7dec1f26e1 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -768,7 +768,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.head = new_embeddings - def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs): + def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, use_cache=None, **kwargs): # only last token for inputs_ids if the state is passed along. if state is not None: input_ids = input_ids[:, -1].unsqueeze(-1) @@ -780,6 +780,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel): model_inputs = {"input_ids": input_ids} model_inputs["state"] = state + model_inputs["use_cache"] = use_cache return model_inputs @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 72da44115f..ae52f6c674 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -194,6 +194,7 @@ class GenerationTesterMixin: output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -207,6 +208,7 @@ class GenerationTesterMixin: output_scores=output_scores, output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **logits_processor_kwargs, **model_kwargs, ) @@ -224,6 +226,7 @@ class GenerationTesterMixin: output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): torch.manual_seed(0) logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True) @@ -239,6 +242,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **logits_processor_kwargs, **model_kwargs, ) @@ -256,6 +260,7 @@ class GenerationTesterMixin: output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -268,6 +273,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **beam_kwargs, **logits_processor_kwargs, **model_kwargs, @@ -286,6 +292,7 @@ class GenerationTesterMixin: output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): torch.manual_seed(0) logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True) @@ -299,6 +306,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **beam_kwargs, **logits_processor_kwargs, **model_kwargs, @@ -317,6 +325,7 @@ class GenerationTesterMixin: output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -329,6 +338,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **beam_kwargs, **logits_processor_kwargs, **model_kwargs, @@ -348,6 +358,7 @@ class GenerationTesterMixin: output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -361,6 +372,7 @@ class GenerationTesterMixin: output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, constraints=constraints, + use_cache=use_cache, **beam_kwargs, **logits_processor_kwargs, **model_kwargs, @@ -378,6 +390,7 @@ class GenerationTesterMixin: output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): contrastive_search_kwargs = { "penalty_alpha": 0.6, @@ -396,6 +409,7 @@ class GenerationTesterMixin: output_scores=output_scores, output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **logits_processor_kwargs, **model_kwargs, **contrastive_search_kwargs, @@ -419,7 +433,6 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = False model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( model=model, @@ -430,6 +443,7 @@ class GenerationTesterMixin: output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -454,7 +468,6 @@ class GenerationTesterMixin: if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( @@ -466,6 +479,7 @@ class GenerationTesterMixin: output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=True, ) if model.config.is_encoder_decoder: @@ -495,7 +509,6 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = False model = model_class(config).to(torch_device).eval() output_generate = self._sample_generate( model=model, @@ -507,6 +520,7 @@ class GenerationTesterMixin: output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -545,9 +559,6 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - # disable cache - config.use_cache = False - model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_search_generate( @@ -560,6 +571,7 @@ class GenerationTesterMixin: output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) @@ -589,7 +601,6 @@ class GenerationTesterMixin: model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() output_generate = self._beam_search_generate( @@ -602,6 +613,7 @@ class GenerationTesterMixin: output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=True, ) if model.config.is_encoder_decoder: @@ -676,9 +688,6 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - # disable cache - config.use_cache = False - model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() @@ -692,6 +701,7 @@ class GenerationTesterMixin: output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -764,7 +774,6 @@ class GenerationTesterMixin: def test_group_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = False model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_diverse_beam_kwargs() @@ -778,6 +787,7 @@ class GenerationTesterMixin: output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) @@ -857,9 +867,6 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - # disable cache - config.use_cache = False - model = model_class(config).to(torch_device).eval() # Sample constraints @@ -882,6 +889,7 @@ class GenerationTesterMixin: output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -913,13 +921,12 @@ class GenerationTesterMixin: # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True # test old generation output for backwards compatibility model = model_class(config).to(torch_device).eval() output_generate = self._contrastive_generate( - model=model, input_ids=input_ids, attention_mask=attention_mask + model=model, input_ids=input_ids, attention_mask=attention_mask, use_cache=True ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) @@ -940,7 +947,6 @@ class GenerationTesterMixin: # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() @@ -953,6 +959,7 @@ class GenerationTesterMixin: output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=True, ) if model.config.is_encoder_decoder: @@ -978,7 +985,6 @@ class GenerationTesterMixin: if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True # test output equality of low versus high memory @@ -991,6 +997,7 @@ class GenerationTesterMixin: low_memory=True, max_new_tokens=self.max_new_tokens, attention_mask=attention_mask, + use_cache=True, ) high_output = model.generate( @@ -1000,6 +1007,7 @@ class GenerationTesterMixin: low_memory=False, max_new_tokens=self.max_new_tokens, attention_mask=attention_mask, + use_cache=True, ) self.assertListEqual(low_output.tolist(), high_output.tolist()) @@ -1031,10 +1039,17 @@ class GenerationTesterMixin: # test output equality of low versus high memory model = model_class(config).to(torch_device).eval() - low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True) + low_output = model.generate( + input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True, use_cache=True + ) high_output = model.generate( - input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False + input_ids, + max_new_tokens=8, + num_beams=5, + early_stopping=True, + low_memory=False, + use_cache=True, ) self.assertListEqual(low_output.tolist(), high_output.tolist()) @@ -1079,7 +1094,6 @@ class GenerationTesterMixin: if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() # Sets assisted generation arguments such that: @@ -1098,6 +1112,7 @@ class GenerationTesterMixin: "output_hidden_states": True, "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": True, } output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) @@ -1150,7 +1165,6 @@ class GenerationTesterMixin: if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() # Sets assisted generation arguments such that: @@ -1169,6 +1183,7 @@ class GenerationTesterMixin: "output_hidden_states": True, "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": True, } output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) @@ -1196,12 +1211,6 @@ class GenerationTesterMixin: # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm config, input_ids, attention_mask = self._get_input_ids_and_config() - # Some models don't support the cache and returning past_key_values - if not hasattr(config, "use_cache"): - config.use_cache = False - else: - config.use_cache = True - # Encoder-decoder models are not supported if config.is_encoder_decoder: self.skipTest("DoLa is not supported for encoder-decoder models") @@ -1224,11 +1233,12 @@ class GenerationTesterMixin: "output_hidden_states": True, "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": hasattr(config, "use_cache"), # Some models don't support the cache } generation_kwargs.update({"dola_layers": "low"}) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs) - self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache) + self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache")) def test_assisted_decoding_sample(self): # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not @@ -1261,7 +1271,6 @@ class GenerationTesterMixin: if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() # Sets assisted generation arguments such that: @@ -1284,6 +1293,7 @@ class GenerationTesterMixin: "output_hidden_states": True, "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": True, } output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) @@ -1566,7 +1576,6 @@ class GenerationTesterMixin: # 3. ignore `token_type_ids` for simplicity # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is # active by default on some models - config.use_cache = True if "token_type_ids" in inputs: del inputs["token_type_ids"] @@ -1574,6 +1583,7 @@ class GenerationTesterMixin: model.eval() model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 model.generation_config.forced_eos_token_id = None + model.generation_config.use_cache = True # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format) outputs = model(**inputs) @@ -1631,7 +1641,6 @@ class GenerationTesterMixin: self.skipTest(reason="This model does not support the new cache format") config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = True model = model_class(config).to(torch_device).eval() generation_kwargs = { @@ -1640,6 +1649,7 @@ class GenerationTesterMixin: "num_beams": num_beams, "num_return_sequences": num_beams, "return_dict_in_generate": True, # Required to return `past_key_values` + "use_cache": True, } # Sets seed before calling `generate` for the case with do_sample=True @@ -1701,7 +1711,6 @@ class GenerationTesterMixin: if config.is_encoder_decoder: self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache") - config.use_cache = True config.is_decoder = True batch_size, seq_length = input_ids.shape max_new_tokens = 20 @@ -1712,6 +1721,7 @@ class GenerationTesterMixin: "max_new_tokens": max_new_tokens, "cache_implementation": "static", "return_dict_in_generate": True, # Required to return `past_key_values` + "use_cache": True, } max_cache_len = seq_length + max_new_tokens @@ -1740,7 +1750,6 @@ class GenerationTesterMixin: self.skipTest(reason="This model does not support the quantized cache format") config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() @@ -1750,6 +1759,7 @@ class GenerationTesterMixin: # careful with group size, should be divisor of model's hidden size "cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128}, "return_dict_in_generate": True, # Required to return `past_key_values` + "use_cache": True, } results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) @@ -1890,22 +1900,24 @@ class GenerationTesterMixin: # Past Key Value States -- a few notes here: # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1" - # 2. Some old models still return `output.past_key_values` even without `use_cache=True` - # 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is - # complete - models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba") + # 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the + # standard cache format (e.g.gptbigcode ) + models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet") has_standard_cache = not any( model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache ) - if use_cache and has_standard_cache: - past_key_values = output.past_key_values - past_sequence_length = output.sequences.shape[-1] - 1 - self._check_past_key_values_for_generate( - num_sequences_in_output, - past_key_values, - seq_length=past_sequence_length, - config=config, - ) + if has_standard_cache: + if use_cache: + past_key_values = output.past_key_values + past_sequence_length = output.sequences.shape[-1] - 1 + self._check_past_key_values_for_generate( + num_sequences_in_output, + past_key_values, + seq_length=past_sequence_length, + config=config, + ) + elif use_cache is False: + self.assertTrue(output.past_key_values is None) def _check_scores(self, batch_size, scores, length, config): expected_shape = (batch_size, config.vocab_size)