Generate: Deprecate returning legacy cache by default; Handle use_cache=False (#32863)
This commit is contained in:
@@ -130,9 +130,29 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
|
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
|
||||||
penalty_alpha (`float`, *optional*):
|
penalty_alpha (`float`, *optional*):
|
||||||
The values balance the model confidence and the degeneration penalty in contrastive search decoding.
|
The values balance the model confidence and the degeneration penalty in contrastive search decoding.
|
||||||
|
dola_layers (`str` or `List[int]`, *optional*):
|
||||||
|
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
|
||||||
|
be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
|
||||||
|
"low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
|
||||||
|
layers up to the last 20 layers.
|
||||||
|
If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
|
||||||
|
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
|
||||||
|
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
|
||||||
|
or [the paper](https://arxiv.org/abs/2309.03883) for more details.
|
||||||
|
|
||||||
|
> Parameters that control the cache
|
||||||
|
|
||||||
use_cache (`bool`, *optional*, defaults to `True`):
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
|
||||||
speed up decoding.
|
speed up decoding.
|
||||||
|
cache_implementation (`str`, *optional*, default to `None`):
|
||||||
|
Cache class that should be used when generating.
|
||||||
|
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
|
||||||
|
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
|
||||||
|
it will be converted to its respective `CacheConfig` internally.
|
||||||
|
Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
|
||||||
|
return_legacy_cache (`bool`, *optional*, default to `True`):
|
||||||
|
Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.
|
||||||
|
|
||||||
> Parameters for manipulation of the model output logits
|
> Parameters for manipulation of the model output logits
|
||||||
|
|
||||||
@@ -307,29 +327,6 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
max_matching_ngram_size (`int`, *optional*, default to `None`):
|
max_matching_ngram_size (`int`, *optional*, default to `None`):
|
||||||
The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.
|
The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.
|
||||||
|
|
||||||
> Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883)
|
|
||||||
|
|
||||||
dola_layers (`str` or `List[int]`, *optional*):
|
|
||||||
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
|
|
||||||
be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
|
|
||||||
"low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
|
|
||||||
layers up to the last 20 layers.
|
|
||||||
If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
|
|
||||||
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
|
|
||||||
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
|
|
||||||
or [the paper](https://arxiv.org/abs/2309.03883) for more details.
|
|
||||||
|
|
||||||
> Parameters specific to the caching mechanism:
|
|
||||||
|
|
||||||
cache_implementation (`str`, *optional*, default to `None`):
|
|
||||||
Cache class that should be used when generating.
|
|
||||||
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
|
|
||||||
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
|
|
||||||
it will be converted to its respective `CacheConfig` internally.
|
|
||||||
Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
|
|
||||||
return_legacy_cache (`bool`, *optional*, default to `True`):
|
|
||||||
Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.
|
|
||||||
|
|
||||||
> Wild card
|
> Wild card
|
||||||
|
|
||||||
generation_kwargs:
|
generation_kwargs:
|
||||||
@@ -352,7 +349,19 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.num_beams = kwargs.pop("num_beams", 1)
|
self.num_beams = kwargs.pop("num_beams", 1)
|
||||||
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
|
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
|
||||||
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
|
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
|
||||||
|
self.dola_layers = kwargs.pop("dola_layers", None)
|
||||||
|
|
||||||
|
# Parameters that control the cache
|
||||||
self.use_cache = kwargs.pop("use_cache", True)
|
self.use_cache = kwargs.pop("use_cache", True)
|
||||||
|
self.cache_implementation = kwargs.pop("cache_implementation", None)
|
||||||
|
self.cache_config = kwargs.pop("cache_config", None)
|
||||||
|
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
|
||||||
|
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
|
||||||
|
if self.cache_config is None:
|
||||||
|
self.cache_config = cache_config_class()
|
||||||
|
elif isinstance(self.cache_config, dict):
|
||||||
|
self.cache_config = cache_config_class.from_dict(self.cache_config)
|
||||||
|
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
|
||||||
|
|
||||||
# Parameters for manipulation of the model output logits
|
# Parameters for manipulation of the model output logits
|
||||||
self.temperature = kwargs.pop("temperature", 1.0)
|
self.temperature = kwargs.pop("temperature", 1.0)
|
||||||
@@ -411,20 +420,6 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
|
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
|
||||||
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
|
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
|
||||||
|
|
||||||
# DoLa generation
|
|
||||||
self.dola_layers = kwargs.pop("dola_layers", None)
|
|
||||||
|
|
||||||
# Cache implementation
|
|
||||||
self.cache_implementation = kwargs.pop("cache_implementation", None)
|
|
||||||
self.cache_config = kwargs.pop("cache_config", None)
|
|
||||||
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
|
|
||||||
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
|
|
||||||
if self.cache_config is None:
|
|
||||||
self.cache_config = cache_config_class()
|
|
||||||
elif isinstance(self.cache_config, dict):
|
|
||||||
self.cache_config = cache_config_class.from_dict(self.cache_config)
|
|
||||||
self.return_legacy_cache = kwargs.pop("return_legacy_cache", True)
|
|
||||||
|
|
||||||
# Prompt lookup decoding
|
# Prompt lookup decoding
|
||||||
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
|
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
|
||||||
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
|
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
|
||||||
@@ -544,8 +539,9 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
|
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
|
||||||
if self.pad_token_id is not None and self.pad_token_id < 0:
|
if self.pad_token_id is not None and self.pad_token_id < 0:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch generating, if there is padding. "
|
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
|
||||||
"Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values."
|
"generating, if there is padding. Please set `pad_token_id` explicitly as "
|
||||||
|
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validation of attribute relations:
|
# Validation of attribute relations:
|
||||||
@@ -675,6 +671,14 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
group_error_prefix
|
group_error_prefix
|
||||||
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
|
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
|
||||||
)
|
)
|
||||||
|
# DoLa generation
|
||||||
|
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
|
||||||
|
warnings.warn(
|
||||||
|
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
|
||||||
|
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
|
||||||
|
"DoLa decoding is `repetition_penalty>=1.2`.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
# 4. check `num_return_sequences`
|
# 4. check `num_return_sequences`
|
||||||
if self.num_return_sequences != 1:
|
if self.num_return_sequences != 1:
|
||||||
@@ -690,7 +694,7 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
f"({self.num_beams})."
|
f"({self.num_beams})."
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. check `cache_config`
|
# 5. check cache-related arguments
|
||||||
if self.cache_config is not None:
|
if self.cache_config is not None:
|
||||||
cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
|
cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
|
||||||
if cache_class is None:
|
if cache_class is None:
|
||||||
@@ -702,6 +706,20 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
if not isinstance(self.cache_config, cache_class):
|
if not isinstance(self.cache_config, cache_class):
|
||||||
self.cache_config = cache_class.from_dict(self.cache_config)
|
self.cache_config = cache_class.from_dict(self.cache_config)
|
||||||
self.cache_config.validate()
|
self.cache_config.validate()
|
||||||
|
if self.use_cache is False:
|
||||||
|
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often
|
||||||
|
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
|
||||||
|
# (otherwise a user might need to overwrite several parameters).
|
||||||
|
no_cache_warning = (
|
||||||
|
"You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will "
|
||||||
|
"have no effect."
|
||||||
|
)
|
||||||
|
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
|
||||||
|
if getattr(self, arg_name) is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)),
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
# 6. check watermarking arguments
|
# 6. check watermarking arguments
|
||||||
if self.watermarking_config is not None:
|
if self.watermarking_config is not None:
|
||||||
@@ -727,17 +745,6 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
"`generate()` (or a pipeline) directly."
|
"`generate()` (or a pipeline) directly."
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2
|
|
||||||
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
|
|
||||||
dola_decoding_wrong_parameter_msg = (
|
|
||||||
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, "
|
|
||||||
"which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`."
|
|
||||||
)
|
|
||||||
warnings.warn(
|
|
||||||
dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty),
|
|
||||||
UserWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
def save_pretrained(
|
def save_pretrained(
|
||||||
self,
|
self,
|
||||||
save_directory: Union[str, os.PathLike],
|
save_directory: Union[str, os.PathLike],
|
||||||
|
|||||||
@@ -136,27 +136,23 @@ class GenerateDecoderOnlyOutput(ModelOutput):
|
|||||||
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
||||||
if all batches finished early due to the `eos_token_id`.
|
if all batches finished early due to the `eos_token_id`.
|
||||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
|
||||||
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
|
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
|
||||||
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
||||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`):
|
||||||
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
Returns the model cache, used to speed up decoding. Different models have a different cache format, check
|
||||||
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
|
||||||
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
|
||||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
|
||||||
encoder_sequence_length, embed_size_per_head)`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -176,36 +172,32 @@ class GenerateEncoderDecoderOutput(ModelOutput):
|
|||||||
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
|
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
|
||||||
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
||||||
if all batches finished early due to the `eos_token_id`.
|
if all batches finished early due to the `eos_token_id`.
|
||||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
|
||||||
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
|
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
|
||||||
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length, sequence_length)`.
|
||||||
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
shape `(batch_size, sequence_length, hidden_size)`.
|
shape `(batch_size, sequence_length, hidden_size)`.
|
||||||
decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
||||||
cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
||||||
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
Returns the model cache, used to speed up decoding. Different models have a different cache format, check
|
||||||
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
|
||||||
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
|
||||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
|
||||||
encoder_sequence_length, embed_size_per_head)`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -228,33 +220,29 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput):
|
|||||||
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
|
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
|
||||||
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
||||||
if all batches finished early due to the `eos_token_id`.
|
if all batches finished early due to the `eos_token_id`.
|
||||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
|
||||||
Final beam scores of the generated `sequences`.
|
Final beam scores of the generated `sequences`.
|
||||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
|
||||||
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
||||||
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||||
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
|
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
|
||||||
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
|
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
|
||||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
|
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
|
||||||
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||||
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
|
||||||
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
||||||
`(batch_size*num_return_sequences, sequence_length)`.
|
`(batch_size*num_return_sequences, sequence_length)`.
|
||||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
|
||||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`):
|
||||||
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
Returns the model cache, used to speed up decoding. Different models have a different cache format, check
|
||||||
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
|
||||||
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
|
||||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
|
||||||
encoder_sequence_length, embed_size_per_head)`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -276,43 +264,39 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):
|
|||||||
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
|
sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
|
||||||
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
||||||
if all batches finished early due to the `eos_token_id`.
|
if all batches finished early due to the `eos_token_id`.
|
||||||
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
|
||||||
Final beam scores of the generated `sequences`.
|
Final beam scores of the generated `sequences`.
|
||||||
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
|
||||||
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting
|
||||||
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
|
||||||
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
|
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
|
||||||
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
|
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
|
||||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
|
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
|
||||||
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||||
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
|
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
|
||||||
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
|
||||||
`(batch_size*num_return_sequences, sequence_length)`.
|
`(batch_size*num_return_sequences, sequence_length)`.
|
||||||
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length, sequence_length)`.
|
||||||
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
|
||||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
|
shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
|
||||||
decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
|
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
|
||||||
sequence_length)`.
|
sequence_length)`.
|
||||||
cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
|
cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
||||||
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
|
||||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||||
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
|
`torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True`):
|
||||||
NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
|
Returns the model cache, used to speed up decoding. Different models have a different cache format, check
|
||||||
Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
|
the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
|
||||||
tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
|
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
|
|
||||||
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
|
|
||||||
encoder_sequence_length, embed_size_per_head)`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: torch.LongTensor = None
|
sequences: torch.LongTensor = None
|
||||||
@@ -328,6 +312,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):
|
|||||||
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
|
||||||
|
|
||||||
|
|
||||||
|
# TODO (joao): remove the equivalent classes and typing shortcuts below in v5
|
||||||
# Equivalent classes (kept for retrocompatibility purposes)
|
# Equivalent classes (kept for retrocompatibility purposes)
|
||||||
GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
|
GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
|
||||||
ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
|
ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
|
||||||
@@ -1501,6 +1486,121 @@ class GenerationMixin:
|
|||||||
"""
|
"""
|
||||||
return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower()
|
return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower()
|
||||||
|
|
||||||
|
def _prepare_cache_for_generation(
|
||||||
|
self,
|
||||||
|
generation_config: GenerationConfig,
|
||||||
|
model_kwargs: Dict,
|
||||||
|
assistant_model: "PreTrainedModel",
|
||||||
|
batch_size: int,
|
||||||
|
device: torch.device,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
|
||||||
|
instantiated, writes it to `model_kwargs`, under the name expected by the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
|
||||||
|
requires_cross_attention_cache = (
|
||||||
|
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Quick escape route 1: if the user specifies a cache, we only need to:
|
||||||
|
# a) check for conflicting `generate` arguments
|
||||||
|
# b) convert to the new cache format (if the user passes a legacy cache and model supports it)
|
||||||
|
user_defined_cache = model_kwargs.get(cache_name)
|
||||||
|
if user_defined_cache is not None:
|
||||||
|
if generation_config.cache_implementation is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
|
||||||
|
"Cache object) is unsupported. Please use only one of the two."
|
||||||
|
)
|
||||||
|
if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache():
|
||||||
|
model_kwargs[cache_name] = (
|
||||||
|
DynamicCache.from_legacy_cache(user_defined_cache)
|
||||||
|
if not requires_cross_attention_cache
|
||||||
|
else EncoderDecoderCache.from_legacy_cache(user_defined_cache)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in
|
||||||
|
# `generation_config.validate()`)
|
||||||
|
if generation_config.use_cache is False:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Quick escape route 3: model that only supports legacy caches = nothing to prepare
|
||||||
|
if not self._supports_default_dynamic_cache():
|
||||||
|
if generation_config.cache_implementation is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"This model does not support `Cache` instances, it only supports the legacy cache format (tuple "
|
||||||
|
f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be "
|
||||||
|
"ignored.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`
|
||||||
|
|
||||||
|
# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
|
||||||
|
# which is only supported in dynamic caches atm
|
||||||
|
if assistant_model is not None and generation_config.cache_implementation is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"An assistant model is provided, using a dynamic cache instead of a cache of type="
|
||||||
|
f"'{generation_config.cache_implementation}'."
|
||||||
|
)
|
||||||
|
generation_config.cache_implementation = None
|
||||||
|
|
||||||
|
if generation_config.cache_implementation is not None:
|
||||||
|
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
||||||
|
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
|
||||||
|
raise ValueError(
|
||||||
|
"This model does not support `cache_implementation='static'`. Please check the following "
|
||||||
|
"issue: https://github.com/huggingface/transformers/issues/28981"
|
||||||
|
)
|
||||||
|
model_kwargs[cache_name] = self._get_cache(
|
||||||
|
cache_implementation=generation_config.cache_implementation,
|
||||||
|
batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
|
||||||
|
max_cache_len=generation_config.max_length,
|
||||||
|
device=device,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
)
|
||||||
|
elif generation_config.cache_implementation == "quantized":
|
||||||
|
if not self._supports_quantized_cache:
|
||||||
|
raise ValueError(
|
||||||
|
"This model does not support the quantized cache. If you want your model to support quantized "
|
||||||
|
"cache, please open an issue and tag @zucchini-nlp."
|
||||||
|
)
|
||||||
|
|
||||||
|
cache_config = (
|
||||||
|
generation_config.cache_config
|
||||||
|
if generation_config.cache_config is not None
|
||||||
|
else QuantizedCacheConfig()
|
||||||
|
)
|
||||||
|
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
|
||||||
|
|
||||||
|
if cache_config.backend == "quanto" and not is_quanto_available():
|
||||||
|
raise ImportError(
|
||||||
|
"You need to install `quanto` in order to use KV cache quantization with quanto backend. "
|
||||||
|
"Please install it via with `pip install quanto`"
|
||||||
|
)
|
||||||
|
elif cache_config.backend == "HQQ" and not is_hqq_available():
|
||||||
|
raise ImportError(
|
||||||
|
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
|
||||||
|
"Please install it via with `pip install hqq`"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_kwargs[cache_name] = cache_class(cache_config)
|
||||||
|
elif generation_config.cache_implementation == "offloaded":
|
||||||
|
model_kwargs[cache_name] = OffloadedCache()
|
||||||
|
|
||||||
|
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
|
||||||
|
# keeps copying the cache thus using much more memory
|
||||||
|
else:
|
||||||
|
model_kwargs[cache_name] = (
|
||||||
|
DynamicCache()
|
||||||
|
if not requires_cross_attention_cache
|
||||||
|
else EncoderDecoderCache(DynamicCache(), DynamicCache())
|
||||||
|
)
|
||||||
|
|
||||||
def _prepare_special_tokens(
|
def _prepare_special_tokens(
|
||||||
self,
|
self,
|
||||||
generation_config: GenerationConfig,
|
generation_config: GenerationConfig,
|
||||||
@@ -1776,104 +1876,18 @@ class GenerationMixin:
|
|||||||
inputs_tensor=inputs_tensor,
|
inputs_tensor=inputs_tensor,
|
||||||
input_ids_length=input_ids_length,
|
input_ids_length=input_ids_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
use_dynamic_cache_by_default = False
|
|
||||||
if "mamba" in self.__class__.__name__.lower():
|
|
||||||
cache_name = "cache_params"
|
|
||||||
else:
|
|
||||||
cache_name = "past_key_values"
|
|
||||||
|
|
||||||
# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
|
|
||||||
# which is only supported in dynamic caches atm
|
|
||||||
if (
|
|
||||||
assistant_model is not None
|
|
||||||
and generation_config.cache_implementation is not None
|
|
||||||
and self._supports_default_dynamic_cache()
|
|
||||||
):
|
|
||||||
logger.warning_once(
|
|
||||||
"An assistant model is provided, using a dynamic cache instead of a cache of type="
|
|
||||||
f"'{generation_config.cache_implementation}'."
|
|
||||||
)
|
|
||||||
generation_config.cache_implementation = None
|
|
||||||
|
|
||||||
if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling():
|
|
||||||
raise ValueError(
|
|
||||||
"Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "
|
|
||||||
"may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` "
|
|
||||||
"input argument."
|
|
||||||
)
|
|
||||||
if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
|
|
||||||
raise ValueError(
|
|
||||||
f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
|
|
||||||
"Cache object) is unsupported. Please use only one of the two."
|
|
||||||
)
|
|
||||||
elif generation_config.cache_implementation is not None:
|
|
||||||
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
|
||||||
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
|
|
||||||
raise ValueError(
|
|
||||||
"This model does not support `cache_implementation='static'`. Please check the following "
|
|
||||||
"issue: https://github.com/huggingface/transformers/issues/28981"
|
|
||||||
)
|
|
||||||
model_kwargs[cache_name] = self._get_cache(
|
|
||||||
cache_implementation=generation_config.cache_implementation,
|
|
||||||
batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
|
|
||||||
max_cache_len=generation_config.max_length,
|
|
||||||
device=device,
|
|
||||||
model_kwargs=model_kwargs,
|
|
||||||
)
|
|
||||||
elif generation_config.cache_implementation == "quantized":
|
|
||||||
if not self._supports_quantized_cache:
|
|
||||||
raise ValueError(
|
|
||||||
"This model does not support the quantized cache. If you want your model to support quantized "
|
|
||||||
"cache, please open an issue."
|
|
||||||
)
|
|
||||||
|
|
||||||
cache_config = (
|
|
||||||
generation_config.cache_config
|
|
||||||
if generation_config.cache_config is not None
|
|
||||||
else QuantizedCacheConfig()
|
|
||||||
)
|
|
||||||
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
|
|
||||||
|
|
||||||
if cache_config.backend == "quanto" and not is_quanto_available():
|
|
||||||
raise ImportError(
|
|
||||||
"You need to install `quanto` in order to use KV cache quantization with quanto backend. "
|
|
||||||
"Please install it via with `pip install quanto`"
|
|
||||||
)
|
|
||||||
elif cache_config.backend == "HQQ" and not is_hqq_available():
|
|
||||||
raise ImportError(
|
|
||||||
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
|
|
||||||
"Please install it via with `pip install hqq`"
|
|
||||||
)
|
|
||||||
|
|
||||||
model_kwargs[cache_name] = cache_class(cache_config)
|
|
||||||
elif generation_config.cache_implementation == "offloaded":
|
|
||||||
model_kwargs[cache_name] = OffloadedCache()
|
|
||||||
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
|
|
||||||
# keeps copying the cache thus using much more memory
|
|
||||||
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
|
|
||||||
past = model_kwargs.get(cache_name, None)
|
|
||||||
requires_cross_attention_cache = (
|
|
||||||
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
|
||||||
)
|
|
||||||
if past is None:
|
|
||||||
model_kwargs[cache_name] = (
|
|
||||||
DynamicCache()
|
|
||||||
if not requires_cross_attention_cache
|
|
||||||
else EncoderDecoderCache(DynamicCache(), DynamicCache())
|
|
||||||
)
|
|
||||||
use_dynamic_cache_by_default = True
|
|
||||||
elif isinstance(past, tuple):
|
|
||||||
model_kwargs[cache_name] = (
|
|
||||||
DynamicCache.from_legacy_cache(past)
|
|
||||||
if not requires_cross_attention_cache
|
|
||||||
else EncoderDecoderCache.from_legacy_cache(past)
|
|
||||||
)
|
|
||||||
use_dynamic_cache_by_default = True
|
|
||||||
|
|
||||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||||
|
|
||||||
# 7. determine generation mode
|
# 7. Prepare the cache.
|
||||||
|
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
|
||||||
|
# - different models have a different cache name expected by the model (default = "past_key_values")
|
||||||
|
# - `max_length`, prepared above, is used to determine the maximum cache length
|
||||||
|
# TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format)
|
||||||
|
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
|
||||||
|
user_defined_cache = model_kwargs.get(cache_name)
|
||||||
|
self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, device)
|
||||||
|
|
||||||
|
# 8. determine generation mode
|
||||||
generation_mode = generation_config.get_generation_mode(assistant_model)
|
generation_mode = generation_config.get_generation_mode(assistant_model)
|
||||||
|
|
||||||
if streamer is not None and (generation_config.num_beams > 1):
|
if streamer is not None and (generation_config.num_beams > 1):
|
||||||
@@ -1892,7 +1906,7 @@ class GenerationMixin:
|
|||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 8. prepare distribution pre_processing samplers
|
# 9. prepare logits processors and stopping criteria
|
||||||
prepared_logits_processor = self._get_logits_processor(
|
prepared_logits_processor = self._get_logits_processor(
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
input_ids_seq_length=input_ids_length,
|
input_ids_seq_length=input_ids_length,
|
||||||
@@ -1904,8 +1918,6 @@ class GenerationMixin:
|
|||||||
negative_prompt_ids=negative_prompt_ids,
|
negative_prompt_ids=negative_prompt_ids,
|
||||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 9. prepare stopping criteria
|
|
||||||
prepared_stopping_criteria = self._get_stopping_criteria(
|
prepared_stopping_criteria = self._get_stopping_criteria(
|
||||||
generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
|
generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
|
||||||
)
|
)
|
||||||
@@ -2138,10 +2150,33 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert to legacy cache if needed
|
# Convert to legacy cache format if requested
|
||||||
if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
|
if (
|
||||||
if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
|
generation_config.return_legacy_cache is not False # Should check for `True` after v4.47
|
||||||
if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)):
|
and not is_torchdynamo_compiling()
|
||||||
|
and hasattr(result, "past_key_values")
|
||||||
|
and hasattr(result.past_key_values, "to_legacy_cache")
|
||||||
|
and result.past_key_values.to_legacy_cache is not None
|
||||||
|
):
|
||||||
|
# handle BC (convert by default if the user hasn't passed a cache AND the cache is of the default type)
|
||||||
|
should_convert_cache = generation_config.return_legacy_cache
|
||||||
|
is_user_defined_cache = user_defined_cache is not None
|
||||||
|
is_default_cache_type = (
|
||||||
|
type(result.past_key_values) == DynamicCache # noqa E721
|
||||||
|
or (
|
||||||
|
isinstance(result.past_key_values, EncoderDecoderCache)
|
||||||
|
and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721
|
||||||
|
and type(result.past_key_values.cross_attention_cache) == DynamicCache # noqa E721
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not is_user_defined_cache and is_default_cache_type:
|
||||||
|
logger.warning_once(
|
||||||
|
"From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` "
|
||||||
|
"instance instead by default (as opposed to the legacy tuple of tuples format). If you want to "
|
||||||
|
"keep returning the legacy format, please set `return_legacy_cache=True`."
|
||||||
|
)
|
||||||
|
should_convert_cache = True
|
||||||
|
if should_convert_cache:
|
||||||
result.past_key_values = result.past_key_values.to_legacy_cache()
|
result.past_key_values = result.past_key_values.to_legacy_cache()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -768,7 +768,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
|
|||||||
def set_output_embeddings(self, new_embeddings):
|
def set_output_embeddings(self, new_embeddings):
|
||||||
self.head = new_embeddings
|
self.head = new_embeddings
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, use_cache=None, **kwargs):
|
||||||
# only last token for inputs_ids if the state is passed along.
|
# only last token for inputs_ids if the state is passed along.
|
||||||
if state is not None:
|
if state is not None:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
@@ -780,6 +780,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
|
|||||||
model_inputs = {"input_ids": input_ids}
|
model_inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
model_inputs["state"] = state
|
model_inputs["state"] = state
|
||||||
|
model_inputs["use_cache"] = use_cache
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
|
||||||
|
|||||||
@@ -194,6 +194,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
|
use_cache=True,
|
||||||
):
|
):
|
||||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
@@ -207,6 +208,7 @@ class GenerationTesterMixin:
|
|||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_logits=output_logits,
|
output_logits=output_logits,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
|
use_cache=use_cache,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -224,6 +226,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
|
use_cache=True,
|
||||||
):
|
):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
||||||
@@ -239,6 +242,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
|
use_cache=use_cache,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -256,6 +260,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
|
use_cache=True,
|
||||||
):
|
):
|
||||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
@@ -268,6 +273,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
|
use_cache=use_cache,
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -286,6 +292,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
|
use_cache=True,
|
||||||
):
|
):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
||||||
@@ -299,6 +306,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
|
use_cache=use_cache,
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -317,6 +325,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
|
use_cache=True,
|
||||||
):
|
):
|
||||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
@@ -329,6 +338,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
|
use_cache=use_cache,
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -348,6 +358,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
|
use_cache=True,
|
||||||
):
|
):
|
||||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
@@ -361,6 +372,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
constraints=constraints,
|
constraints=constraints,
|
||||||
|
use_cache=use_cache,
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -378,6 +390,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
|
use_cache=True,
|
||||||
):
|
):
|
||||||
contrastive_search_kwargs = {
|
contrastive_search_kwargs = {
|
||||||
"penalty_alpha": 0.6,
|
"penalty_alpha": 0.6,
|
||||||
@@ -396,6 +409,7 @@ class GenerationTesterMixin:
|
|||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_logits=output_logits,
|
output_logits=output_logits,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
|
use_cache=use_cache,
|
||||||
**logits_processor_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
**contrastive_search_kwargs,
|
**contrastive_search_kwargs,
|
||||||
@@ -419,7 +433,6 @@ class GenerationTesterMixin:
|
|||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
config.use_cache = False
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -430,6 +443,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=self.has_attentions,
|
output_attentions=self.has_attentions,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
use_cache=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -454,7 +468,6 @@ class GenerationTesterMixin:
|
|||||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||||
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
@@ -466,6 +479,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=self.has_attentions,
|
output_attentions=self.has_attentions,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -495,7 +509,6 @@ class GenerationTesterMixin:
|
|||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
config.use_cache = False
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -507,6 +520,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=self.has_attentions,
|
output_attentions=self.has_attentions,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
use_cache=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -545,9 +559,6 @@ class GenerationTesterMixin:
|
|||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# disable cache
|
|
||||||
config.use_cache = False
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
output_generate = self._beam_search_generate(
|
output_generate = self._beam_search_generate(
|
||||||
@@ -560,6 +571,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=self.has_attentions,
|
output_attentions=self.has_attentions,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
use_cache=False,
|
||||||
)
|
)
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
@@ -589,7 +601,6 @@ class GenerationTesterMixin:
|
|||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._beam_search_generate(
|
output_generate = self._beam_search_generate(
|
||||||
@@ -602,6 +613,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=self.has_attentions,
|
output_attentions=self.has_attentions,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -676,9 +688,6 @@ class GenerationTesterMixin:
|
|||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# disable cache
|
|
||||||
config.use_cache = False
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
|
|
||||||
@@ -692,6 +701,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=self.has_attentions,
|
output_attentions=self.has_attentions,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
use_cache=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -764,7 +774,6 @@ class GenerationTesterMixin:
|
|||||||
def test_group_beam_search_generate_dict_output(self):
|
def test_group_beam_search_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.use_cache = False
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||||
@@ -778,6 +787,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=self.has_attentions,
|
output_attentions=self.has_attentions,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
use_cache=False,
|
||||||
)
|
)
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
@@ -857,9 +867,6 @@ class GenerationTesterMixin:
|
|||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# disable cache
|
|
||||||
config.use_cache = False
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
# Sample constraints
|
# Sample constraints
|
||||||
@@ -882,6 +889,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=self.has_attentions,
|
output_attentions=self.has_attentions,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
use_cache=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -913,13 +921,12 @@ class GenerationTesterMixin:
|
|||||||
# NOTE: contrastive search only works with cache on at the moment.
|
# NOTE: contrastive search only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest(reason="This model doesn't support caching")
|
self.skipTest(reason="This model doesn't support caching")
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
|
|
||||||
# test old generation output for backwards compatibility
|
# test old generation output for backwards compatibility
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._contrastive_generate(
|
output_generate = self._contrastive_generate(
|
||||||
model=model, input_ids=input_ids, attention_mask=attention_mask
|
model=model, input_ids=input_ids, attention_mask=attention_mask, use_cache=True
|
||||||
)
|
)
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
@@ -940,7 +947,6 @@ class GenerationTesterMixin:
|
|||||||
# NOTE: contrastive search only works with cache on at the moment.
|
# NOTE: contrastive search only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest(reason="This model doesn't support caching")
|
self.skipTest(reason="This model doesn't support caching")
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
@@ -953,6 +959,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=self.has_attentions,
|
output_attentions=self.has_attentions,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -978,7 +985,6 @@ class GenerationTesterMixin:
|
|||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest(reason="This model doesn't support caching")
|
self.skipTest(reason="This model doesn't support caching")
|
||||||
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
|
|
||||||
# test output equality of low versus high memory
|
# test output equality of low versus high memory
|
||||||
@@ -991,6 +997,7 @@ class GenerationTesterMixin:
|
|||||||
low_memory=True,
|
low_memory=True,
|
||||||
max_new_tokens=self.max_new_tokens,
|
max_new_tokens=self.max_new_tokens,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
high_output = model.generate(
|
high_output = model.generate(
|
||||||
@@ -1000,6 +1007,7 @@ class GenerationTesterMixin:
|
|||||||
low_memory=False,
|
low_memory=False,
|
||||||
max_new_tokens=self.max_new_tokens,
|
max_new_tokens=self.max_new_tokens,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
use_cache=True,
|
||||||
)
|
)
|
||||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||||
|
|
||||||
@@ -1031,10 +1039,17 @@ class GenerationTesterMixin:
|
|||||||
# test output equality of low versus high memory
|
# test output equality of low versus high memory
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)
|
low_output = model.generate(
|
||||||
|
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True, use_cache=True
|
||||||
|
)
|
||||||
|
|
||||||
high_output = model.generate(
|
high_output = model.generate(
|
||||||
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
|
input_ids,
|
||||||
|
max_new_tokens=8,
|
||||||
|
num_beams=5,
|
||||||
|
early_stopping=True,
|
||||||
|
low_memory=False,
|
||||||
|
use_cache=True,
|
||||||
)
|
)
|
||||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||||
|
|
||||||
@@ -1079,7 +1094,6 @@ class GenerationTesterMixin:
|
|||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest(reason="This model doesn't support caching")
|
self.skipTest(reason="This model doesn't support caching")
|
||||||
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
# Sets assisted generation arguments such that:
|
# Sets assisted generation arguments such that:
|
||||||
@@ -1098,6 +1112,7 @@ class GenerationTesterMixin:
|
|||||||
"output_hidden_states": True,
|
"output_hidden_states": True,
|
||||||
"output_attentions": self.has_attentions,
|
"output_attentions": self.has_attentions,
|
||||||
"return_dict_in_generate": True,
|
"return_dict_in_generate": True,
|
||||||
|
"use_cache": True,
|
||||||
}
|
}
|
||||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||||
|
|
||||||
@@ -1150,7 +1165,6 @@ class GenerationTesterMixin:
|
|||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest(reason="This model doesn't support caching")
|
self.skipTest(reason="This model doesn't support caching")
|
||||||
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
# Sets assisted generation arguments such that:
|
# Sets assisted generation arguments such that:
|
||||||
@@ -1169,6 +1183,7 @@ class GenerationTesterMixin:
|
|||||||
"output_hidden_states": True,
|
"output_hidden_states": True,
|
||||||
"output_attentions": self.has_attentions,
|
"output_attentions": self.has_attentions,
|
||||||
"return_dict_in_generate": True,
|
"return_dict_in_generate": True,
|
||||||
|
"use_cache": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||||
@@ -1196,12 +1211,6 @@ class GenerationTesterMixin:
|
|||||||
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
|
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# Some models don't support the cache and returning past_key_values
|
|
||||||
if not hasattr(config, "use_cache"):
|
|
||||||
config.use_cache = False
|
|
||||||
else:
|
|
||||||
config.use_cache = True
|
|
||||||
|
|
||||||
# Encoder-decoder models are not supported
|
# Encoder-decoder models are not supported
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
self.skipTest("DoLa is not supported for encoder-decoder models")
|
self.skipTest("DoLa is not supported for encoder-decoder models")
|
||||||
@@ -1224,11 +1233,12 @@ class GenerationTesterMixin:
|
|||||||
"output_hidden_states": True,
|
"output_hidden_states": True,
|
||||||
"output_attentions": self.has_attentions,
|
"output_attentions": self.has_attentions,
|
||||||
"return_dict_in_generate": True,
|
"return_dict_in_generate": True,
|
||||||
|
"use_cache": hasattr(config, "use_cache"), # Some models don't support the cache
|
||||||
}
|
}
|
||||||
generation_kwargs.update({"dola_layers": "low"})
|
generation_kwargs.update({"dola_layers": "low"})
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
|
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
|
||||||
self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache)
|
self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache"))
|
||||||
|
|
||||||
def test_assisted_decoding_sample(self):
|
def test_assisted_decoding_sample(self):
|
||||||
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
|
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
|
||||||
@@ -1261,7 +1271,6 @@ class GenerationTesterMixin:
|
|||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest(reason="This model doesn't support caching")
|
self.skipTest(reason="This model doesn't support caching")
|
||||||
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
# Sets assisted generation arguments such that:
|
# Sets assisted generation arguments such that:
|
||||||
@@ -1284,6 +1293,7 @@ class GenerationTesterMixin:
|
|||||||
"output_hidden_states": True,
|
"output_hidden_states": True,
|
||||||
"output_attentions": self.has_attentions,
|
"output_attentions": self.has_attentions,
|
||||||
"return_dict_in_generate": True,
|
"return_dict_in_generate": True,
|
||||||
|
"use_cache": True,
|
||||||
}
|
}
|
||||||
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||||
|
|
||||||
@@ -1566,7 +1576,6 @@ class GenerationTesterMixin:
|
|||||||
# 3. ignore `token_type_ids` for simplicity
|
# 3. ignore `token_type_ids` for simplicity
|
||||||
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
|
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
|
||||||
# active by default on some models
|
# active by default on some models
|
||||||
config.use_cache = True
|
|
||||||
if "token_type_ids" in inputs:
|
if "token_type_ids" in inputs:
|
||||||
del inputs["token_type_ids"]
|
del inputs["token_type_ids"]
|
||||||
|
|
||||||
@@ -1574,6 +1583,7 @@ class GenerationTesterMixin:
|
|||||||
model.eval()
|
model.eval()
|
||||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
|
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
|
||||||
model.generation_config.forced_eos_token_id = None
|
model.generation_config.forced_eos_token_id = None
|
||||||
|
model.generation_config.use_cache = True
|
||||||
|
|
||||||
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
|
# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
@@ -1631,7 +1641,6 @@ class GenerationTesterMixin:
|
|||||||
self.skipTest(reason="This model does not support the new cache format")
|
self.skipTest(reason="This model does not support the new cache format")
|
||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.use_cache = True
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
@@ -1640,6 +1649,7 @@ class GenerationTesterMixin:
|
|||||||
"num_beams": num_beams,
|
"num_beams": num_beams,
|
||||||
"num_return_sequences": num_beams,
|
"num_return_sequences": num_beams,
|
||||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||||
|
"use_cache": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Sets seed before calling `generate` for the case with do_sample=True
|
# Sets seed before calling `generate` for the case with do_sample=True
|
||||||
@@ -1701,7 +1711,6 @@ class GenerationTesterMixin:
|
|||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
|
self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
|
||||||
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
max_new_tokens = 20
|
max_new_tokens = 20
|
||||||
@@ -1712,6 +1721,7 @@ class GenerationTesterMixin:
|
|||||||
"max_new_tokens": max_new_tokens,
|
"max_new_tokens": max_new_tokens,
|
||||||
"cache_implementation": "static",
|
"cache_implementation": "static",
|
||||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||||
|
"use_cache": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
max_cache_len = seq_length + max_new_tokens
|
max_cache_len = seq_length + max_new_tokens
|
||||||
@@ -1740,7 +1750,6 @@ class GenerationTesterMixin:
|
|||||||
self.skipTest(reason="This model does not support the quantized cache format")
|
self.skipTest(reason="This model does not support the quantized cache format")
|
||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
@@ -1750,6 +1759,7 @@ class GenerationTesterMixin:
|
|||||||
# careful with group size, should be divisor of model's hidden size
|
# careful with group size, should be divisor of model's hidden size
|
||||||
"cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128},
|
"cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128},
|
||||||
"return_dict_in_generate": True, # Required to return `past_key_values`
|
"return_dict_in_generate": True, # Required to return `past_key_values`
|
||||||
|
"use_cache": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||||
@@ -1890,14 +1900,14 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
# Past Key Value States -- a few notes here:
|
# Past Key Value States -- a few notes here:
|
||||||
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
|
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
|
||||||
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
|
# 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refactoring to match the
|
||||||
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
|
# standard cache format (e.g. gptbigcode)
|
||||||
# complete
|
models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet")
|
||||||
models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
|
|
||||||
has_standard_cache = not any(
|
has_standard_cache = not any(
|
||||||
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
|
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
|
||||||
)
|
)
|
||||||
if use_cache and has_standard_cache:
|
if has_standard_cache:
|
||||||
|
if use_cache:
|
||||||
past_key_values = output.past_key_values
|
past_key_values = output.past_key_values
|
||||||
past_sequence_length = output.sequences.shape[-1] - 1
|
past_sequence_length = output.sequences.shape[-1] - 1
|
||||||
self._check_past_key_values_for_generate(
|
self._check_past_key_values_for_generate(
|
||||||
@@ -1906,6 +1916,8 @@ class GenerationTesterMixin:
|
|||||||
seq_length=past_sequence_length,
|
seq_length=past_sequence_length,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
elif use_cache is False:
|
||||||
|
self.assertTrue(output.past_key_values is None)
|
||||||
|
|
||||||
def _check_scores(self, batch_size, scores, length, config):
|
def _check_scores(self, batch_size, scores, length, config):
|
||||||
expected_shape = (batch_size, config.vocab_size)
|
expected_shape = (batch_size, config.vocab_size)
|
||||||
|
|||||||
Reference in New Issue
Block a user