[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)

* shape checks compatible with static cache

* add test

* tmp

* manually turn on eager attn when we want to output attn

* typo

* generalize to encoder-decoder models

* force compilation on cpu

* tmp commit

* fix static cache shape checks

* models with odd caches

* fix copies

* shorter cache search loop

* use decoder_past_key_values everywhere

* better test variable names and comments

* signature

* rename _check_outputs into _check_generate_outputs

* add comments

* HybridCache future test note
This commit is contained in:
Joao Gante
2025-02-10 17:50:54 +00:00
committed by GitHub
parent 9510ae39d9
commit be2ac0916a
25 changed files with 379 additions and 917 deletions

View File

@@ -116,6 +116,16 @@ if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module from accelerate.hooks import AlignDevicesHook, add_hook_to_module
# Variable names used to hold the cache at generation time
ALL_CACHE_NAMES = [
"past_key_values", # default
"cache_params", # mamba-based models
"state", # rwkv
"mems", # xlnet
"past_buckets_states", # reformer
]
@dataclass @dataclass
class GenerateDecoderOnlyOutput(ModelOutput): class GenerateDecoderOnlyOutput(ModelOutput):
""" """
@@ -756,21 +766,6 @@ class GenerationMixin:
return input_ids, model_kwargs return input_ids, model_kwargs
def _extract_past_from_model_output(self, outputs: ModelOutput):
past_key_values = None
cache_name = "past_key_values"
if "past_key_values" in outputs:
past_key_values = outputs.past_key_values
elif "mems" in outputs:
past_key_values = outputs.mems
elif "past_buckets_states" in outputs:
past_key_values = outputs.past_buckets_states
elif "cache_params" in outputs:
past_key_values = outputs.cache_params
cache_name = "cache_params"
return cache_name, past_key_values
def _update_model_kwargs_for_generation( def _update_model_kwargs_for_generation(
self, self,
outputs: ModelOutput, outputs: ModelOutput,
@@ -779,10 +774,15 @@ class GenerationMixin:
num_new_tokens: int = 1, num_new_tokens: int = 1,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# update past_key_values keeping its naming used in model code # update past_key_values keeping its naming used in model code
cache_name, cache = self._extract_past_from_model_output(outputs) for possible_cache_name in ALL_CACHE_NAMES:
model_kwargs[cache_name] = cache if possible_cache_name in outputs:
if getattr(outputs, "state", None) is not None: # TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecated
model_kwargs["state"] = outputs.state if possible_cache_name in ("past_buckets_states", "mems"):
cache_name = "past_key_values"
else:
cache_name = possible_cache_name
model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
break
# update token_type_ids with last value # update token_type_ids with last value
if "token_type_ids" in model_kwargs: if "token_type_ids" in model_kwargs:
@@ -2087,7 +2087,7 @@ class GenerationMixin:
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
# - different models have a different cache name expected by the model (default = "past_key_values") # - different models have a different cache name expected by the model (default = "past_key_values")
# - `max_length`, prepared above, is used to determine the maximum cache length # - `max_length`, prepared above, is used to determine the maximum cache length
max_cache_length = generation_config.max_length max_cache_length = generation_config.max_length - 1
if ( if (
inputs_tensor.shape[1] != input_ids_length inputs_tensor.shape[1] != input_ids_length
and model_input_name == "inputs_embeds" and model_input_name == "inputs_embeds"
@@ -2994,7 +2994,9 @@ class GenerationMixin:
next_past_key_values = selected_outputs["past_key_values"] next_past_key_values = selected_outputs["past_key_values"]
else: else:
_, next_past_key_values = self._extract_past_from_model_output(outputs) next_past_key_values = None
for possible_cache_name in ALL_CACHE_NAMES:
next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
# Do it in-place layer per layer to save memory # Do it in-place layer per layer to save memory
if isinstance(next_past_key_values, DynamicCache) or ( if isinstance(next_past_key_values, DynamicCache) or (
isinstance(next_past_key_values, EncoderDecoderCache) isinstance(next_past_key_values, EncoderDecoderCache)

View File

@@ -2557,32 +2557,3 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
return outputs return outputs
else: else:
return output_values return output_values
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
model_inputs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# update past_key_values
cache_name, cache = self._extract_past_from_model_output(outputs)
model_kwargs[cache_name] = cache
if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
# update decoder attention mask
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
model_kwargs["decoder_attention_mask"] = torch.cat(
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
dim=-1,
)
return model_kwargs

View File

@@ -16,7 +16,7 @@
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@@ -1329,54 +1329,6 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
) )
return model_inputs return model_inputs
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
) -> Dict[str, Any]:
# update past_key_values keeping its naming used in model code
cache_name, cache = self._extract_past_from_model_output(outputs)
model_kwargs[cache_name] = cache
if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state
# update attention_mask
if getattr(outputs, "attention_mask", None) is not None:
model_kwargs["attention_mask"] = outputs.attention_mask
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
if not is_encoder_decoder:
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
else:
# update decoder attention mask
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
model_kwargs["decoder_attention_mask"] = torch.cat(
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
dim=-1,
)
if model_kwargs.get("use_cache", True):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
else:
past_positions = model_kwargs.pop("cache_position")
new_positions = torch.arange(
past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
).to(past_positions.device)
model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
return model_kwargs
def _reorder_cache(self, *args, **kwargs): def _reorder_cache(self, *args, **kwargs):
return self.language_model._reorder_cache(*args, **kwargs) return self.language_model._reorder_cache(*args, **kwargs)

View File

@@ -72,7 +72,14 @@ if is_torch_available():
SpeechEncoderDecoderModel, SpeechEncoderDecoderModel,
T5ForConditionalGeneration, T5ForConditionalGeneration,
) )
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.cache_utils import (
Cache,
DynamicCache,
EncoderDecoderCache,
HybridCache,
QuantoQuantizedCache,
StaticCache,
)
from transformers.generation import ( from transformers.generation import (
BeamSampleDecoderOnlyOutput, BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput, BeamSampleEncoderDecoderOutput,
@@ -473,6 +480,8 @@ class GenerationTesterMixin:
def test_greedy_generate_dict_outputs(self): def test_greedy_generate_dict_outputs(self):
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate() config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
output_generate = self._greedy_generate( output_generate = self._greedy_generate(
@@ -499,12 +508,14 @@ class GenerationTesterMixin:
# Retrocompatibility check # Retrocompatibility check
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
self._check_outputs(output_generate, model.config) self._check_generate_outputs(output_generate, model.config)
@pytest.mark.generate @pytest.mark.generate
def test_greedy_generate_dict_outputs_use_cache(self): def test_greedy_generate_dict_outputs_use_cache(self):
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate() config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
if not hasattr(config, "use_cache"): if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -531,7 +542,7 @@ class GenerationTesterMixin:
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
) )
self._check_outputs(output_generate, model.config, use_cache=True) self._check_generate_outputs(output_generate, model.config, use_cache=True)
@pytest.mark.generate @pytest.mark.generate
def test_sample_generate(self): def test_sample_generate(self):
@@ -550,6 +561,8 @@ class GenerationTesterMixin:
def test_sample_generate_dict_output(self): def test_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate() config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
output_generate = self._sample_generate( output_generate = self._sample_generate(
@@ -577,7 +590,7 @@ class GenerationTesterMixin:
# Retrocompatibility check # Retrocompatibility check
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput) self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
self._check_outputs(output_generate, model.config, num_return_sequences=2) self._check_generate_outputs(output_generate, model.config, num_return_sequences=2)
@pytest.mark.generate @pytest.mark.generate
def test_beam_search_generate(self): def test_beam_search_generate(self):
@@ -598,6 +611,8 @@ class GenerationTesterMixin:
def test_beam_search_generate_dict_output(self): def test_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate() config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs() beam_kwargs = self._get_beam_kwargs()
@@ -625,7 +640,7 @@ class GenerationTesterMixin:
# Retrocompatibility check # Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs( self._check_generate_outputs(
output_generate, output_generate,
model.config, model.config,
num_return_sequences=beam_kwargs["num_return_sequences"], num_return_sequences=beam_kwargs["num_return_sequences"],
@@ -642,6 +657,8 @@ class GenerationTesterMixin:
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs() beam_kwargs = self._get_beam_kwargs()
@@ -666,7 +683,7 @@ class GenerationTesterMixin:
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
) )
self._check_outputs( self._check_generate_outputs(
output_generate, output_generate,
model.config, model.config,
use_cache=True, use_cache=True,
@@ -721,6 +738,8 @@ class GenerationTesterMixin:
def test_beam_sample_generate_dict_output(self): def test_beam_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate() config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs() beam_kwargs = self._get_beam_kwargs()
@@ -750,7 +769,7 @@ class GenerationTesterMixin:
# Retrocompatibility check # Retrocompatibility check
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
self._check_outputs( self._check_generate_outputs(
output_generate, output_generate,
model.config, model.config,
num_return_sequences=beam_kwargs["num_return_sequences"], num_return_sequences=beam_kwargs["num_return_sequences"],
@@ -813,6 +832,8 @@ class GenerationTesterMixin:
def test_group_beam_search_generate_dict_output(self): def test_group_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate() config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_diverse_beam_kwargs() beam_kwargs = self._get_diverse_beam_kwargs()
@@ -840,7 +861,7 @@ class GenerationTesterMixin:
# Retrocompatibility check # Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs( self._check_generate_outputs(
output_generate, output_generate,
model.config, model.config,
num_return_sequences=beam_kwargs["num_return_sequences"], num_return_sequences=beam_kwargs["num_return_sequences"],
@@ -909,6 +930,8 @@ class GenerationTesterMixin:
def test_constrained_beam_search_generate_dict_output(self): def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate() config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
@@ -947,7 +970,7 @@ class GenerationTesterMixin:
# Retrocompatibility check # Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs( self._check_generate_outputs(
output_generate, output_generate,
model.config, model.config,
num_return_sequences=beam_kwargs["num_return_sequences"], num_return_sequences=beam_kwargs["num_return_sequences"],
@@ -999,6 +1022,8 @@ class GenerationTesterMixin:
if not hasattr(config, "use_cache"): if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching") self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True config.is_decoder = True
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
model = model_class(config).to(torch_device).eval() model = model_class(config).to(torch_device).eval()
output_generate = self._contrastive_generate( output_generate = self._contrastive_generate(
@@ -1019,7 +1044,7 @@ class GenerationTesterMixin:
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
) )
self._check_outputs(output_generate, model.config, use_cache=True) self._check_generate_outputs(output_generate, model.config, use_cache=True)
@pytest.mark.generate @pytest.mark.generate
def test_contrastive_generate_low_memory(self): def test_contrastive_generate_low_memory(self):
@@ -1205,7 +1230,7 @@ class GenerationTesterMixin:
# The two outputs must match and their shape must be as expected # The two outputs must match and their shape must be as expected
self._check_similar_generate_outputs(output_greedy, output_assisted) self._check_similar_generate_outputs(output_greedy, output_assisted)
for output in (output_greedy, output_assisted): for output in (output_greedy, output_assisted):
self._check_outputs(output, model.config, use_cache=True) self._check_generate_outputs(output, model.config, use_cache=True)
@pytest.mark.generate @pytest.mark.generate
def test_prompt_lookup_decoding_matches_greedy_search(self): def test_prompt_lookup_decoding_matches_greedy_search(self):
@@ -1270,7 +1295,7 @@ class GenerationTesterMixin:
# The two outputs must match and their shape must be as expected # The two outputs must match and their shape must be as expected
self._check_similar_generate_outputs(output_greedy, output_prompt_lookup) self._check_similar_generate_outputs(output_greedy, output_prompt_lookup)
for output in (output_greedy, output_prompt_lookup): for output in (output_greedy, output_prompt_lookup):
self._check_outputs(output, model.config, use_cache=True) self._check_generate_outputs(output, model.config, use_cache=True)
@pytest.mark.generate @pytest.mark.generate
def test_dola_decoding_sample(self): def test_dola_decoding_sample(self):
@@ -1320,7 +1345,7 @@ class GenerationTesterMixin:
"dola_layers": "low", "dola_layers": "low",
} }
output_dola = model.generate(**generation_kwargs, **logits_processor_kwargs, **inputs_dict) output_dola = model.generate(**generation_kwargs, **logits_processor_kwargs, **inputs_dict)
self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False)) self._check_generate_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
@pytest.mark.generate @pytest.mark.generate
def test_assisted_decoding_sample(self): def test_assisted_decoding_sample(self):
@@ -1381,7 +1406,7 @@ class GenerationTesterMixin:
} }
output_assisted = model.generate(**generation_kwargs, **inputs_dict) output_assisted = model.generate(**generation_kwargs, **inputs_dict)
self._check_outputs(output_assisted, config, use_cache=True) self._check_generate_outputs(output_assisted, config, use_cache=True)
@pytest.mark.generate @pytest.mark.generate
def test_prompt_lookup_decoding_stops_at_eos(self): def test_prompt_lookup_decoding_stops_at_eos(self):
@@ -1419,6 +1444,8 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate() config, inputs_dict = self.prepare_config_and_inputs_for_generate()
text_config = config.get_text_config() text_config = config.get_text_config()
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
# We want to test only encoder-decoder models # We want to test only encoder-decoder models
if not text_config.is_encoder_decoder: if not text_config.is_encoder_decoder:
@@ -1765,7 +1792,7 @@ class GenerationTesterMixin:
num_hidden_layers = text_config.num_hidden_layers num_hidden_layers = text_config.num_hidden_layers
inputs_embeds = model.get_input_embeddings()(input_ids) inputs_embeds = model.get_input_embeddings()(input_ids)
max_cache_len += inputs_embeds.shape[1] max_cache_len += inputs_embeds.shape[1] - 1 # the last generated token has no cache
outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict) outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict)
# we should get `max_length` in shape, not `max_length - embeds_length` # we should get `max_length` in shape, not `max_length - embeds_length`
@@ -2003,7 +2030,7 @@ class GenerationTesterMixin:
) )
# Check 1: The cache shapes must match the expected shapes # Check 1: The cache shapes must match the expected shapes
max_cache_len = seq_length + max_new_tokens max_cache_len = seq_length + max_new_tokens - 1 # cache len = gen len - 1, the last token has no cache
text_config = config.text_config if hasattr(config, "text_config") else config text_config = config.text_config if hasattr(config, "text_config") else config
head_dim = ( head_dim = (
text_config.head_dim text_config.head_dim
@@ -2138,6 +2165,58 @@ class GenerationTesterMixin:
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs): for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
self._check_similar_generate_outputs(dynamic_result, compiled_result) self._check_similar_generate_outputs(dynamic_result, compiled_result)
@pytest.mark.generate
def test_generate_compilation_all_outputs(self):
"""
Tests that all optional outputs are behaving as expected when compilation is triggered.
In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered.
"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
model = model_class(config).to(torch_device).eval()
# compilation-specific setup
torch.compiler.reset() # prevent cached compilation from being used in the test
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
if not has_defined_cache_implementation:
model.generation_config.cache_implementation = "static"
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
output_generate = model.generate(
do_sample=False,
num_beams=1,
max_new_tokens=self.max_new_tokens,
min_new_tokens=self.max_new_tokens,
output_attentions=True,
output_hidden_states=True,
output_scores=True,
output_logits=True,
return_dict_in_generate=True,
use_cache=True,
**logits_processor_kwargs,
**inputs_dict,
)
# Sanity check: compilation has happened
self.assertTrue(hasattr(model, "_compiled_call"))
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
else:
self.assertTrue(
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
)
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self._check_generate_outputs(output_generate, model.config, use_cache=True)
@pytest.mark.generate @pytest.mark.generate
def test_generate_methods_with_logits_to_keep(self): def test_generate_methods_with_logits_to_keep(self):
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
@@ -2290,86 +2369,90 @@ class GenerationTesterMixin:
# check whether we still need the overwrites # check whether we still need the overwrites
self._test_attention_implementation("flash_attention_2") self._test_attention_implementation("flash_attention_2")
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
input_batch_size = int(output.sequences.shape[0] / num_return_sequences) input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
internal_batch_size = ( internal_batch_size = (
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
) )
seq_length = getattr(self.model_tester, "seq_length", None) prompt_length = getattr(self.model_tester, "seq_length", None)
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length) prompt_length = getattr(self.model_tester, "encoder_seq_length", prompt_length)
seq_length = getattr(self.model_tester, "text_seq_length", seq_length) prompt_length = getattr(self.model_tester, "text_seq_length", prompt_length)
config = config.text_config if hasattr(config, "text_config") else config config = config.text_config if hasattr(config, "text_config") else config
gen_len = ( generated_length = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length
) )
decoder_past_key_values = getattr(output, "past_key_values", None)
if config.is_encoder_decoder and isinstance(decoder_past_key_values, EncoderDecoderCache):
decoder_past_key_values = decoder_past_key_values.self_attention_cache
# in some models we subsample the sequence length in inner layers # in some models we subsample the sequence length in inner layers
if hasattr(self.model_tester, "get_subsampled_output_lengths"): if hasattr(self.model_tester, "get_subsampled_output_lengths"):
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length) prompt_length = self.model_tester.get_subsampled_output_lengths(prompt_length)
# scores # scores
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config) self._check_scores(
batch_size=internal_batch_size, scores=output.scores, generated_length=generated_length, config=config
)
# unprocessed logits # unprocessed logits
self._check_logits(internal_batch_size, output.logits, config=config) self._check_logits(batch_size=internal_batch_size, logits=output.logits, config=config)
# Attentions # Attentions
if self.has_attentions: if self.has_attentions:
if config.is_encoder_decoder: if config.is_encoder_decoder:
# encoder # encoder
self._check_encoder_attention_for_generate( self._check_encoder_attention_for_generate(
output.encoder_attentions, input_batch_size, config, seq_length attentions=output.encoder_attentions,
batch_size=input_batch_size,
config=config,
prompt_length=prompt_length,
) )
# decoder # decoder
self._check_attentions_for_generate( self._check_attentions_for_generate(
internal_batch_size, batch_size=internal_batch_size,
output.decoder_attentions, attentions=output.decoder_attentions,
min_length=1, prompt_length=1, # the BOS token
max_length=output.sequences.shape[-1], output_length=output.sequences.shape[-1],
config=config, config=config,
use_cache=use_cache, decoder_past_key_values=decoder_past_key_values,
) )
else: else:
# if use_cache first input is equal to no use_cache, so skip here
attentions = output.attentions if not use_cache else output.attentions[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_attentions_for_generate( self._check_attentions_for_generate(
internal_batch_size, batch_size=internal_batch_size,
attentions=attentions, attentions=output.attentions,
min_length=min_length, prompt_length=prompt_length,
max_length=output.sequences.shape[-1], output_length=output.sequences.shape[-1],
config=config, config=config,
use_cache=use_cache, decoder_past_key_values=decoder_past_key_values,
) )
# Hidden States # Hidden States
if config.is_encoder_decoder: if config.is_encoder_decoder:
# encoder # encoder
self._check_encoder_hidden_states_for_generate( self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, input_batch_size, config, seq_length hidden_states=output.encoder_hidden_states,
batch_size=input_batch_size,
config=config,
prompt_length=prompt_length,
) )
# decoder # decoder
self._check_hidden_states_for_generate( self._check_hidden_states_for_generate(
internal_batch_size, batch_size=internal_batch_size,
output.decoder_hidden_states, hidden_states=output.decoder_hidden_states,
min_length=1, prompt_length=1, # the BOS token
max_length=output.sequences.shape[-1], output_length=output.sequences.shape[-1],
config=config, config=config,
use_cache=use_cache, use_cache=use_cache,
) )
else: else:
# if use_cache first input is equal to no use_cache, so skip here
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_hidden_states_for_generate( self._check_hidden_states_for_generate(
internal_batch_size, batch_size=internal_batch_size,
hidden_states, hidden_states=output.hidden_states,
min_length=min_length, prompt_length=prompt_length,
max_length=output.sequences.shape[-1], output_length=output.sequences.shape[-1],
config=config, config=config,
use_cache=use_cache, use_cache=use_cache,
) )
@@ -2396,59 +2479,73 @@ class GenerationTesterMixin:
) )
if has_standard_cache: if has_standard_cache:
if use_cache: if use_cache:
past_key_values = output.past_key_values cache_length = output.sequences.shape[-1] - 1
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate( self._check_past_key_values_for_generate(
internal_batch_size, batch_size=internal_batch_size,
past_key_values, decoder_past_key_values=decoder_past_key_values,
seq_length=past_sequence_length, cache_length=cache_length,
config=config, config=config,
) )
elif use_cache is False: elif use_cache is False:
self.assertTrue(output.past_key_values is None) self.assertTrue(decoder_past_key_values is None)
def _check_scores(self, batch_size, scores, length, config): def _check_scores(self, batch_size, scores, generated_length, config):
vocab_size = config.get_text_config(decoder=True).vocab_size vocab_size = config.get_text_config(decoder=True).vocab_size
expected_shape = (batch_size, vocab_size) expected_shape = (batch_size, vocab_size)
self.assertIsInstance(scores, tuple) self.assertIsInstance(scores, tuple)
self.assertEqual(len(scores), length) self.assertEqual(len(scores), generated_length)
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
def _check_logits(self, batch_size, scores, config): def _check_logits(self, batch_size, logits, config):
vocab_size = config.get_text_config(decoder=True).vocab_size vocab_size = config.get_text_config(decoder=True).vocab_size
self.assertIsInstance(scores, tuple) self.assertIsInstance(logits, tuple)
self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores)) self.assertListEqual([iter_logits.shape[0] for iter_logits in logits], [batch_size] * len(logits))
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models) # vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
vocab_diff = vocab_size - scores[0].shape[-1] vocab_diff = vocab_size - logits[0].shape[-1]
self.assertTrue(vocab_diff in [0, 1]) self.assertTrue(vocab_diff in [0, 1])
self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores)) self.assertListEqual([vocab_size - score.shape[-1] for score in logits], [vocab_diff] * len(logits))
def _check_attentions_for_generate( def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
): ):
self.assertIsInstance(attentions, tuple) self.assertIsInstance(attentions, tuple)
self.assertListEqual( self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
) )
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) self.assertEqual(len(attentions), (output_length - prompt_length))
for idx, iter_attentions in enumerate(attentions): use_cache = decoder_past_key_values is not None
tgt_len = min_length + idx if not use_cache else 1 has_static_cache = isinstance(decoder_past_key_values, (StaticCache, HybridCache))
src_len = min_length + idx
# When `output_attentions=True`, each iteration of generate appends the attentions corresponding to the new
# token(s)
# NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more
# elaborate checks
for generated_length, iter_attentions in enumerate(attentions):
# regardless of using cache, the first forward pass will have the full prompt as input
if use_cache and generated_length > 0:
model_input_length = 1
else:
model_input_length = prompt_length + generated_length
query_length = (
prompt_length + generated_length
if not has_static_cache
else decoder_past_key_values.get_max_cache_shape()
)
expected_shape = ( expected_shape = (
batch_size * num_beam_groups, batch_size,
config.num_attention_heads, config.num_attention_heads,
tgt_len, model_input_length,
src_len, query_length,
) )
# check attn size # check attn size
self.assertListEqual( self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
) )
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length) encoder_expected_shape = (batch_size, config.num_attention_heads, prompt_length, prompt_length)
self.assertIsInstance(attentions, tuple) self.assertIsInstance(attentions, tuple)
self.assertListEqual( self.assertListEqual(
[layer_attentions.shape for layer_attentions in attentions], [layer_attentions.shape for layer_attentions in attentions],
@@ -2456,71 +2553,75 @@ class GenerationTesterMixin:
) )
def _check_hidden_states_for_generate( def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
): ):
self.assertIsInstance(hidden_states, tuple) self.assertIsInstance(hidden_states, tuple)
self.assertListEqual( self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states), [True] * len(hidden_states),
) )
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) self.assertEqual(len(hidden_states), (output_length - prompt_length))
for idx, iter_hidden_states in enumerate(hidden_states): # When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
seq_len = min_length + idx if not use_cache else 1 # new token(s)
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) # NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more
# elaborate checks
for generated_length, iter_hidden_states in enumerate(hidden_states):
# regardless of using cache, the first forward pass will have the full prompt as input
if use_cache and generated_length > 0:
model_input_length = 1
else:
model_input_length = prompt_length + generated_length
expected_shape = (batch_size, model_input_length, config.hidden_size)
# check hidden size # check hidden size
self.assertListEqual( self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
[expected_shape] * len(iter_hidden_states), [expected_shape] * len(iter_hidden_states),
) )
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length): def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
encoder_expected_shape = (batch_size, seq_length, config.hidden_size) encoder_expected_shape = (batch_size, prompt_length, config.hidden_size)
self.assertIsInstance(hidden_states, tuple) self.assertIsInstance(hidden_states, tuple)
self.assertListEqual( self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in hidden_states], [layer_hidden_states.shape for layer_hidden_states in hidden_states],
[encoder_expected_shape] * len(hidden_states), [encoder_expected_shape] * len(hidden_states),
) )
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1): def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
self.assertIsInstance(past_key_values, (tuple, Cache)) self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
# Encoder-decoder models: pull and verify the decoder cache
if isinstance(past_key_values, EncoderDecoderCache):
past_key_values = past_key_values.self_attention_cache
# (batch, head, seq_length, head_features) # (batch, head, seq_length, head_features)
expected_shape = ( expected_shape = (
batch_size * num_beam_groups, batch_size,
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
seq_length, cache_length,
config.hidden_size // config.num_attention_heads, config.hidden_size // config.num_attention_heads,
) )
if isinstance(past_key_values, Cache): if isinstance(decoder_past_key_values, Cache):
self.assertListEqual( self.assertListEqual(
[key_tensor.shape for key_tensor in past_key_values.key_cache], [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_shape] * len(past_key_values.key_cache), [expected_shape] * len(decoder_past_key_values.key_cache),
) )
self.assertListEqual( self.assertListEqual(
[value_tensor.shape for value_tensor in past_key_values.value_cache], [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
[expected_shape] * len(past_key_values.value_cache), [expected_shape] * len(decoder_past_key_values.value_cache),
) )
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default # Legacy cache format checks. This branch should be removed when all models use `Cache` by default
else: else:
self.assertListEqual( self.assertListEqual(
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values], [isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values],
[True] * len(past_key_values), [True] * len(decoder_past_key_values),
) )
# check shape key, value # check shape key, value
self.assertListEqual( self.assertListEqual(
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values], [layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values],
[expected_shape] * len(past_key_values), [expected_shape] * len(decoder_past_key_values),
) )
self.assertListEqual( self.assertListEqual(
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values], [layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values],
[expected_shape] * len(past_key_values), [expected_shape] * len(decoder_past_key_values),
) )
def _check_sequence_inside_sequence(self, tensor_1, tensor_2): def _check_sequence_inside_sequence(self, tensor_1, tensor_2):

View File

@@ -723,103 +723,12 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
self.assertIsNotNone(model) self.assertIsNotNone(model)
# overwrite because BLIP internally calls LM.generate() with embeds thus it cannot operate in no cache format # overwrite because BLIP internally calls LM.generate() with embeds thus it cannot operate in no cache format
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
use_cache = True # force this to be True in case False is passed use_cache = True # force this to be True in case False is passed
super()._check_generate_outputs(
input_batch_size = int(output.sequences.shape[0] / num_return_sequences) output, config, use_cache=use_cache, num_return_sequences=num_return_sequences, num_beams=num_beams
internal_batch_size = (
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
) )
seq_length = getattr(self.model_tester, "seq_length", None)
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
config = config.text_config if hasattr(config, "text_config") else config
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
)
# in some models we subsample the sequence length in inner layers
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
# scores
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
# unprocessed logits
self._check_logits(internal_batch_size, output.logits, config=config)
# Attentions
if self.has_attentions:
if config.is_encoder_decoder:
# encoder
self._check_encoder_attention_for_generate(
output.encoder_attentions, input_batch_size, config, seq_length
)
# decoder
self._check_attentions_for_generate(
internal_batch_size,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
else:
# if use_cache first input is equal to no use_cache, so skip here
attentions = output.attentions if not use_cache else output.attentions[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_attentions_for_generate(
internal_batch_size,
attentions=attentions,
min_length=min_length,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Hidden States
if config.is_encoder_decoder:
# encoder
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, input_batch_size, config, seq_length
)
# decoder
self._check_hidden_states_for_generate(
internal_batch_size,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
else:
# if use_cache first input is equal to no use_cache, so skip here
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_hidden_states_for_generate(
internal_batch_size,
hidden_states,
min_length=min_length,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Past Key Value States
if use_cache:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
internal_batch_size,
past_key_values,
seq_length=past_sequence_length,
config=config,
)
# overwrite because BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present # overwrite because BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present
@pytest.mark.generate @pytest.mark.generate
def test_left_padding_compatibility(self): def test_left_padding_compatibility(self):

View File

@@ -20,7 +20,7 @@ from packaging import version
from parameterized import parameterized from parameterized import parameterized
from pytest import mark from pytest import mark
from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, HybridCache, is_torch_available, pipeline from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, is_torch_available, pipeline
from transformers.generation.configuration_utils import GenerationConfig from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import ( from transformers.testing_utils import (
require_flash_attn, require_flash_attn,
@@ -135,51 +135,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
def test_generate_continue_from_inputs_embeds(self): def test_generate_continue_from_inputs_embeds(self):
pass pass
# overwrite because HybridCache has fixed length for key/values
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
for idx, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1
src_len = min_length + idx if not use_cache else max_length
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
# overwrite because HybridCache has fixed length for key/values
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
self.assertIsInstance(past_key_values, HybridCache)
# check shape key, value (batch, head, max_seq_length, head_features)
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers
# we should get `max_length` in shape, not `max_length - embeds_length`
# `+1` because the test in Mixin subtracts 1 which is needed for tuple cache
static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim)
static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)
@unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different") @unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
def test_sdpa_equivalence(self): def test_sdpa_equivalence(self):
pass pass

View File

@@ -20,7 +20,7 @@ from packaging import version
from parameterized import parameterized from parameterized import parameterized
from pytest import mark from pytest import mark
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
from transformers.generation.configuration_utils import GenerationConfig from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import ( from transformers.testing_utils import (
require_flash_attn, require_flash_attn,
@@ -150,51 +150,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_generate_continue_from_inputs_embeds(self): def test_generate_continue_from_inputs_embeds(self):
pass pass
# overwrite because HybridCache has fixed length for key/values
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
for idx, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1
src_len = min_length + idx if not use_cache else max_length
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
# overwrite because HybridCache has fixed length for key/values
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
self.assertIsInstance(past_key_values, HybridCache)
# check shape key, value (batch, head, max_seq_length, head_features)
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers
# we should get `max_length` in shape, not `max_length - embeds_length`
# `+1` because the test in Mixin subtracts 1 which is needed for tuple cache
static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim)
static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different") @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_sdpa_equivalence(self): def test_sdpa_equivalence(self):
pass pass

View File

@@ -456,51 +456,26 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
def _check_attentions_for_generate( def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
): ):
# GIT attention shape depends on image inputs, overwrite # GIT attention shape depends on image inputs, overwrite
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1) image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
prompt_length += image_length
for idx, iter_attentions in enumerate(attentions): output_length += image_length
tgt_len = min_length + idx + image_length if not use_cache else 1 super()._check_attentions_for_generate(
src_len = min_length + idx + image_length batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
)
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
def _check_hidden_states_for_generate( def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
): ):
# GIT attention shape depends on image inputs, overwrite # GIT attention shape depends on image inputs, overwrite
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1) image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
prompt_length += image_length
for idx, iter_hidden_states in enumerate(hidden_states): output_length += image_length
seq_len = min_length + idx + image_length if not use_cache else 1 super()._check_hidden_states_for_generate(
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) batch_size, hidden_states, prompt_length, output_length, config, use_cache=use_cache
# check hidden size )
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
[expected_shape] * len(iter_hidden_states),
)
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):

View File

@@ -815,7 +815,7 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
) )
def _check_attentions_for_generate( def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
): ):
""" """
Overwrite from generation tests because Idefics has only SDPA layers. Overwrite from generation tests because Idefics has only SDPA layers.

View File

@@ -251,10 +251,10 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
return inputs_dict return inputs_dict
# we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausalImageModeling doesn't have tied input- and output embeddings # we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausalImageModeling doesn't have tied input- and output embeddings
def _check_scores(self, batch_size, scores, length, config): def _check_scores(self, batch_size, scores, generated_length, config):
expected_shape = (batch_size, config.vocab_size - 1) expected_shape = (batch_size, config.vocab_size - 1)
self.assertIsInstance(scores, tuple) self.assertIsInstance(scores, tuple)
self.assertEqual(len(scores), length) self.assertEqual(len(scores), generated_length)
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
@run_test_using_subprocess @run_test_using_subprocess

View File

@@ -565,103 +565,12 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
self.assertIsNotNone(model) self.assertIsNotNone(model)
# overwrite because InstructBLIP internally calls LM.generate() with embeds thus it cannot operate in no cache format # overwrite because InstructBLIP internally calls LM.generate() with embeds thus it cannot operate in no cache format
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
use_cache = True # force this to be True in case False is passed use_cache = True # force this to be True in case False is passed
super()._check_generate_outputs(
input_batch_size = int(output.sequences.shape[0] / num_return_sequences) output, config, use_cache=use_cache, num_return_sequences=num_return_sequences, num_beams=num_beams
internal_batch_size = (
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
) )
seq_length = getattr(self.model_tester, "seq_length", None)
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
config = config.text_config if hasattr(config, "text_config") else config
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
)
# in some models we subsample the sequence length in inner layers
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
# scores
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
# unprocessed logits
self._check_logits(internal_batch_size, output.logits, config=config)
# Attentions
if self.has_attentions:
if config.is_encoder_decoder:
# encoder
self._check_encoder_attention_for_generate(
output.encoder_attentions, input_batch_size, config, seq_length
)
# decoder
self._check_attentions_for_generate(
internal_batch_size,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
else:
# if use_cache first input is equal to no use_cache, so skip here
attentions = output.attentions if not use_cache else output.attentions[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_attentions_for_generate(
internal_batch_size,
attentions=attentions,
min_length=min_length,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Hidden States
if config.is_encoder_decoder:
# encoder
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, input_batch_size, config, seq_length
)
# decoder
self._check_hidden_states_for_generate(
internal_batch_size,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
else:
# if use_cache first input is equal to no use_cache, so skip here
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_hidden_states_for_generate(
internal_batch_size,
hidden_states,
min_length=min_length,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Past Key Value States
if use_cache:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
internal_batch_size,
past_key_values,
seq_length=past_sequence_length,
config=config,
)
# overwrite because InstructBLIP cannot generate only from input ids, and requires `pixel` values and `qformer_input_ids` in all cases to be present # overwrite because InstructBLIP cannot generate only from input ids, and requires `pixel` values and `qformer_input_ids` in all cases to be present
@pytest.mark.generate @pytest.mark.generate
def test_left_padding_compatibility(self): def test_left_padding_compatibility(self):

View File

@@ -581,103 +581,12 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
self.assertIsNotNone(model) self.assertIsNotNone(model)
# overwrite because InstructBLIPVideo internally calls LM.generate() with embeds thus it cannot operate in no cache format # overwrite because InstructBLIPVideo internally calls LM.generate() with embeds thus it cannot operate in no cache format
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
use_cache = True # force this to be True in case False is passed use_cache = True # force this to be True in case False is passed
super()._check_generate_outputs(
input_batch_size = int(output.sequences.shape[0] / num_return_sequences) output, config, use_cache=use_cache, num_return_sequences=num_return_sequences, num_beams=num_beams
internal_batch_size = (
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
) )
seq_length = getattr(self.model_tester, "seq_length", None)
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
config = config.text_config if hasattr(config, "text_config") else config
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
)
# in some models we subsample the sequence length in inner layers
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
# scores
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
# unprocessed logits
self._check_logits(internal_batch_size, output.logits, config=config)
# Attentions
if self.has_attentions:
if config.is_encoder_decoder:
# encoder
self._check_encoder_attention_for_generate(
output.encoder_attentions, input_batch_size, config, seq_length
)
# decoder
self._check_attentions_for_generate(
internal_batch_size,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
else:
# if use_cache first input is equal to no use_cache, so skip here
attentions = output.attentions if not use_cache else output.attentions[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_attentions_for_generate(
internal_batch_size,
attentions=attentions,
min_length=min_length,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Hidden States
if config.is_encoder_decoder:
# encoder
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, input_batch_size, config, seq_length
)
# decoder
self._check_hidden_states_for_generate(
internal_batch_size,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
else:
# if use_cache first input is equal to no use_cache, so skip here
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_hidden_states_for_generate(
internal_batch_size,
hidden_states,
min_length=min_length,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Past Key Value States
if use_cache:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
internal_batch_size,
past_key_values,
seq_length=past_sequence_length,
config=config,
)
# overwrite because InstructBLIPVideo cannot generate only from input ids, and requires `pixel` values and `qformer_input_ids` in all cases to be present # overwrite because InstructBLIPVideo cannot generate only from input ids, and requires `pixel` values and `qformer_input_ids` in all cases to be present
@pytest.mark.generate @pytest.mark.generate
def test_left_padding_compatibility(self): def test_left_padding_compatibility(self):

View File

@@ -468,12 +468,12 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
], ],
) )
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
# overwrite because LED does not have (bs, num_heads, seq_len, seq_len) shape # overwrite because LED does not have (bs, num_heads, seq_len, seq_len) shape
encoder_expected_shape = ( encoder_expected_shape = (
batch_size, batch_size,
config.num_attention_heads, config.num_attention_heads,
seq_length, prompt_length,
self.model_tester.attention_window // 2 * 2 + 1, self.model_tester.attention_window // 2 * 2 + 1,
) )
self.assertIsInstance(attentions, tuple) self.assertIsInstance(attentions, tuple)

View File

@@ -785,7 +785,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
[self.model_tester.num_attention_heads, block_len, 3 * block_len], [self.model_tester.num_attention_heads, block_len, 3 * block_len],
) )
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
block_len = getattr(self.model_tester, "block_len", None) block_len = getattr(self.model_tester, "block_len", None)
encoder_expected_shape = (batch_size, 2, config.num_attention_heads, block_len, 3 * block_len) encoder_expected_shape = (batch_size, 2, config.num_attention_heads, block_len, 3 * block_len)
self.assertIsInstance(attentions, tuple) self.assertIsInstance(attentions, tuple)
@@ -920,10 +920,10 @@ class LongT5TGlobalModelTest(LongT5ModelTest):
[self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len], [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
) )
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
block_len = getattr(self.model_tester, "block_len", None) block_len = getattr(self.model_tester, "block_len", None)
global_block_size = getattr(self.model_tester, "global_block_size", None) global_block_size = getattr(self.model_tester, "global_block_size", None)
global_seq_length = seq_length // global_block_size global_seq_length = prompt_length // global_block_size
encoder_expected_shape = ( encoder_expected_shape = (
batch_size, batch_size,
2, 2,

View File

@@ -323,32 +323,37 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
torch.testing.assert_close(out_embeds, out_ids) torch.testing.assert_close(out_embeds, out_ids)
def _check_attentions_for_generate( def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
): ):
# Mllama has cross attention layers and those have a different shape than normal attention layers # Mllama has cross attention layers and those have a different shape than normal attention layers
self.assertIsInstance(attentions, tuple) self.assertIsInstance(attentions, tuple)
self.assertListEqual( self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
) )
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) self.assertEqual(len(attentions), (output_length - prompt_length))
cross_attention_layers = self.model_tester.text_config["cross_attention_layers"] cross_attention_layers = self.model_tester.text_config["cross_attention_layers"]
use_cache = decoder_past_key_values is not None
for idx, iter_attentions in enumerate(attentions): for generated_length, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1 # regardless of using cache, the first forward pass will have the full prompt as input
src_len = min_length + idx if use_cache and generated_length > 0:
model_input_length = 1
else:
model_input_length = prompt_length + generated_length
query_length = prompt_length + generated_length
expected_shape = ( expected_shape = (
batch_size * num_beam_groups, batch_size,
config.num_attention_heads, config.num_attention_heads,
tgt_len, model_input_length,
src_len, query_length,
) )
expected_shape_cross = ( expected_shape_cross = (
batch_size * num_beam_groups, batch_size,
config.num_attention_heads, config.num_attention_heads,
tgt_len, model_input_length,
self.model_tester.image_length, self.model_tester.image_length,
) )

View File

@@ -575,77 +575,12 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
return config, filtered_inputs_dict return config, filtered_inputs_dict
def _check_hidden_states_for_generate( def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
# Overwrite because the generate method actually alway uses `inputs_embeds` so `use_cache` is always `True` # Overwrite because the generate method actually alway uses `inputs_embeds` so `use_cache` is always `True`
self.assertIsInstance(hidden_states, tuple) super()._check_generate_outputs(
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
for idx, iter_hidden_states in enumerate(hidden_states):
seq_len = min_length if idx == 0 else 1
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
# check hidden size
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
[expected_shape] * len(iter_hidden_states),
)
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
# Overwrite because the generate method actually alway uses `inputs_embeds` so `use_cache` is always `True`
super()._check_outputs(
output, config, use_cache=True, num_return_sequences=num_return_sequences, num_beams=num_beams output, config, use_cache=True, num_return_sequences=num_return_sequences, num_beams=num_beams
) )
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
# Overwrite because the generate method actually alway uses `inputs_embeds` so `use_cache` is always `True`
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
for idx, iter_hidden_states in enumerate(hidden_states):
seq_len = 1
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
# check hidden size
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
[expected_shape] * len(iter_hidden_states),
)
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
# Overwrite because the generate method actually alway uses `inputs_embeds` so `use_cache` is always `True`
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
for idx, iter_attentions in enumerate(attentions):
tgt_len = 1
src_len = min_length + idx
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
def test_initialization(self): def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -399,11 +399,11 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
], ],
) )
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
encoder_expected_shape = ( encoder_expected_shape = (
batch_size, batch_size,
config.num_attention_heads, config.num_attention_heads,
math.ceil(seq_length / config.block_size), math.ceil(prompt_length / config.block_size),
config.block_size, config.block_size,
config.block_size + config.num_global_tokens, config.block_size + config.num_global_tokens,
) )
@@ -413,8 +413,8 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
[encoder_expected_shape] * len(attentions), [encoder_expected_shape] * len(attentions),
) )
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length): def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
encoder_expected_shape = (batch_size, self.round_up(seq_length, config.block_size), config.hidden_size) encoder_expected_shape = (batch_size, self.round_up(prompt_length, config.block_size), config.hidden_size)
self.assertIsInstance(hidden_states, tuple) self.assertIsInstance(hidden_states, tuple)
# Only the last layer will have the hidden states truncated back to token level # Only the last layer will have the hidden states truncated back to token level
self.assertListEqual( self.assertListEqual(
@@ -424,7 +424,7 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
# Only the last layer will have the hidden states truncated back to token level # Only the last layer will have the hidden states truncated back to token level
self.assertEqual( self.assertEqual(
hidden_states[-1][0].shape, hidden_states[-1][0].shape,
(batch_size, seq_length, config.hidden_size), (batch_size, prompt_length, config.hidden_size),
) )
def test_hidden_states_output(self): def test_hidden_states_output(self):

View File

@@ -753,20 +753,20 @@ class Pix2StructModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
text_config = Pix2StructTextConfig.from_pretrained(tmp_dir_name) text_config = Pix2StructTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict()) self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
# overwrite because # pix2struct seq length depends on image inputs # overwrite because # pix2struct seq length depends on image inputs
seq_length = self.model_tester.max_patches prompt_length = self.model_tester.max_patches
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length) encoder_expected_shape = (batch_size, config.num_attention_heads, prompt_length, prompt_length)
self.assertIsInstance(attentions, tuple) self.assertIsInstance(attentions, tuple)
self.assertListEqual( self.assertListEqual(
[layer_attentions.shape for layer_attentions in attentions], [layer_attentions.shape for layer_attentions in attentions],
[encoder_expected_shape] * len(attentions), [encoder_expected_shape] * len(attentions),
) )
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length): def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
# overwrite because # pix2struct seq length depends on image inputs # overwrite because # pix2struct seq length depends on image inputs
seq_length = self.model_tester.max_patches prompt_length = self.model_tester.max_patches
encoder_expected_shape = (batch_size, seq_length, config.hidden_size) encoder_expected_shape = (batch_size, prompt_length, config.hidden_size)
self.assertIsInstance(hidden_states, tuple) self.assertIsInstance(hidden_states, tuple)
self.assertListEqual( self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in hidden_states], [layer_hidden_states.shape for layer_hidden_states in hidden_states],

View File

@@ -367,9 +367,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
def test_training_gradient_checkpointing_use_reentrant_false(self): def test_training_gradient_checkpointing_use_reentrant_false(self):
pass pass
def _check_attentions_for_generate(self, *args, **kwargs):
return True # Model does not return attention
@unittest.skip(reason="Past key values are not returned") @unittest.skip(reason="Past key values are not returned")
def test_prompt_lookup_decoding_matches_greedy_search(self): def test_prompt_lookup_decoding_matches_greedy_search(self):
pass pass
@@ -382,9 +379,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
def test_model_parallel_beam_search(self): def test_model_parallel_beam_search(self):
pass pass
def _check_past_key_values_for_generate(self, *args, **kwargs):
return True
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported") @unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
def test_assisted_decoding_matches_greedy_search(self): def test_assisted_decoding_matches_greedy_search(self):
pass pass
@@ -397,25 +391,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
def test_assisted_decoding_sample(self): def test_assisted_decoding_sample(self):
pass pass
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
for idx, iter_hidden_states in enumerate(hidden_states):
seq_len = min_length + idx if not use_cache else 1
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
# check hidden size
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
[expected_shape] * len(iter_hidden_states),
)
@unittest.skip(reason="TODO @arthurzucker not super important and failing.") @unittest.skip(reason="TODO @arthurzucker not super important and failing.")
def test_initialization(self): def test_initialization(self):
pass pass

View File

@@ -620,36 +620,42 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
self.assertIsNotNone(model) self.assertIsNotNone(model)
def _check_attentions_for_generate( def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
): ):
# NOTE (joao): this function is substancially different from the original, the attention has different
# *number* of shapes in certain conditions
self.assertIsInstance(attentions, tuple) self.assertIsInstance(attentions, tuple)
self.assertListEqual( self.assertListEqual(
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions) [isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
) )
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) self.assertEqual(len(attentions), (output_length - prompt_length))
for idx, iter_attentions in enumerate(attentions): for generated_length, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1 use_cache = decoder_past_key_values is not None and generated_length > 0
num_chunks = tgt_len // config.local_attn_chunk_length + (tgt_len % config.local_attn_chunk_length != 0)
tgt_chunk_len = config.local_attn_chunk_length model_input_length = prompt_length + generated_length if not use_cache else 1
src_chunk_len = config.local_attn_chunk_length * ( num_chunks = model_input_length // config.local_attn_chunk_length + (
model_input_length % config.local_attn_chunk_length != 0
)
model_input_chunk_len = config.local_attn_chunk_length
query_chunk_len = config.local_attn_chunk_length * (
1 + config.local_num_chunks_after + config.local_num_chunks_before 1 + config.local_num_chunks_after + config.local_num_chunks_before
) )
if use_cache: if use_cache:
expected_shape = ( expected_shape = (
batch_size * num_beam_groups, batch_size,
config.num_attention_heads, config.num_attention_heads,
tgt_len, model_input_length,
min_length // config.local_attn_chunk_length + 1 + idx, prompt_length // config.local_attn_chunk_length + generated_length,
) )
else: else:
expected_shape = ( expected_shape = (
batch_size * num_beam_groups, batch_size,
config.num_attention_heads, config.num_attention_heads,
num_chunks, num_chunks,
tgt_chunk_len, model_input_chunk_len,
src_chunk_len, query_chunk_len,
) )
# check attn size # check attn size
self.assertListEqual( self.assertListEqual(
@@ -657,25 +663,29 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
) )
def _check_hidden_states_for_generate( def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
): ):
# NOTE (joao): this function is substancially different from the original, the hidden states have different
# length in certain conditions
self.assertIsInstance(hidden_states, tuple) self.assertIsInstance(hidden_states, tuple)
self.assertListEqual( self.assertListEqual(
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states], [isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
[True] * len(hidden_states), [True] * len(hidden_states),
) )
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) self.assertEqual(len(hidden_states), (output_length - prompt_length))
for idx, iter_hidden_states in enumerate(hidden_states): for generation_length, iter_hidden_states in enumerate(hidden_states):
seq_len = min_length + idx use_cache_this_iter = use_cache and generation_length > 0
seq_len = config.local_attn_chunk_length * ( model_input_length = prompt_length + generation_length
seq_len // config.local_attn_chunk_length + (seq_len % config.local_attn_chunk_length != 0) model_output_length = config.local_attn_chunk_length * (
model_input_length // config.local_attn_chunk_length
+ (model_input_length % config.local_attn_chunk_length != 0)
) )
if use_cache: if use_cache_this_iter:
seq_len = 1 model_output_length = 1
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) expected_shape = (batch_size, model_output_length, config.hidden_size)
# check hidden size # check hidden size
self.assertListEqual( self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
@@ -789,37 +799,42 @@ class ReformerLSHAttnModelTest(
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
def _check_attentions_for_generate( def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
): ):
# NOTE (joao): this function is substancially different from the original, the attention has different
# *number* of shapes in certain conditions
self.assertIsInstance(attentions, tuple) self.assertIsInstance(attentions, tuple)
self.assertListEqual( self.assertListEqual(
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions) [isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
) )
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) self.assertEqual(len(attentions), (output_length - prompt_length))
for idx, iter_attentions in enumerate(attentions): for generated_length, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1 use_cache = decoder_past_key_values is not None and generated_length > 0
num_chunks = tgt_len // config.lsh_attn_chunk_length + (tgt_len % config.lsh_attn_chunk_length != 0) model_input_len = prompt_length + generated_length if not use_cache else 1
tgt_chunk_len = config.lsh_attn_chunk_length num_chunks = model_input_len // config.lsh_attn_chunk_length + (
src_chunk_len = config.lsh_attn_chunk_length * ( model_input_len % config.lsh_attn_chunk_length != 0
)
model_input_chunk_len = config.lsh_attn_chunk_length
query_chunk_len = config.lsh_attn_chunk_length * (
1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before 1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before
) )
if use_cache: if use_cache:
expected_shape = ( expected_shape = (
batch_size * num_beam_groups, batch_size,
config.num_attention_heads, config.num_attention_heads,
config.num_hashes, config.num_hashes,
tgt_len, model_input_len,
config.num_hashes * (1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before), config.num_hashes * (1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before),
) )
else: else:
expected_shape = ( expected_shape = (
batch_size * num_beam_groups, batch_size,
config.num_attention_heads, config.num_attention_heads,
num_chunks * config.num_hashes, num_chunks * config.num_hashes,
tgt_chunk_len, model_input_chunk_len,
src_chunk_len, query_chunk_len,
) )
# check attn size # check attn size
self.assertListEqual( self.assertListEqual(
@@ -827,25 +842,29 @@ class ReformerLSHAttnModelTest(
) )
def _check_hidden_states_for_generate( def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
): ):
# NOTE (joao): this function is substancially different from the original, the hidden states have different
# length in certain conditions
self.assertIsInstance(hidden_states, tuple) self.assertIsInstance(hidden_states, tuple)
self.assertListEqual( self.assertListEqual(
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states], [isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
[True] * len(hidden_states), [True] * len(hidden_states),
) )
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) self.assertEqual(len(hidden_states), (output_length - prompt_length))
for idx, iter_hidden_states in enumerate(hidden_states): for generation_length, iter_hidden_states in enumerate(hidden_states):
seq_len = min_length + idx if not use_cache else 1 use_cache_this_iter = use_cache and generation_length > 0
seq_len = config.lsh_attn_chunk_length * ( model_input_length = prompt_length + generation_length
seq_len // config.lsh_attn_chunk_length + (seq_len % config.lsh_attn_chunk_length != 0) model_output_length = config.local_attn_chunk_length * (
model_input_length // config.local_attn_chunk_length
+ (model_input_length % config.local_attn_chunk_length != 0)
) )
if use_cache: if use_cache_this_iter:
seq_len = 1 model_output_length = 1
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) expected_shape = (batch_size, model_output_length, config.hidden_size)
# check hidden size # check hidden size
self.assertListEqual( self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],

View File

@@ -416,48 +416,6 @@ class TFSpeech2TextModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.T
def test_generate_without_input_ids(self): def test_generate_without_input_ids(self):
pass pass
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape[:2]
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
num_sequences_in_output = batch_size * num_return_sequences
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
)
# scores
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
# Attentions
# encoder
self._check_encoder_attention_for_generate(
output.encoder_attentions, batch_size, config, subsampled_seq_length
)
# decoder
self._check_attentions_for_generate(
num_sequences_in_output,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Hidden States
# encoder
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
)
# decoder
self._check_hidden_states_for_generate(
num_sequences_in_output,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is # overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
# `input_features` # `input_features`
def test_lm_head_model_random_no_beam_search_generate(self): def test_lm_head_model_random_no_beam_search_generate(self):

View File

@@ -527,48 +527,6 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
def test_generate_without_input_ids(self): def test_generate_without_input_ids(self):
pass pass
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, mel, seq_length = input_ids.shape
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
num_sequences_in_output = batch_size * num_return_sequences
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
)
# scores
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
# Attentions
# encoder
self._check_encoder_attention_for_generate(
output.encoder_attentions, batch_size, config, subsampled_seq_length
)
# decoder
self._check_attentions_for_generate(
num_sequences_in_output,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Hidden States
# encoder
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
)
# decoder
self._check_hidden_states_for_generate(
num_sequences_in_output,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is # overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
# `input_features` # `input_features`
def test_lm_head_model_random_no_beam_search_generate(self): def test_lm_head_model_random_no_beam_search_generate(self):

View File

@@ -1607,6 +1607,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_compile_model_forward(self): def test_generate_compile_model_forward(self):
pass pass
# TODO (joao, eustache): fix me :)
@unittest.skip(reason="A CUDA exception is thrown when storing extra outputs")
def test_generate_compilation_all_outputs(self):
pass
@require_torch @require_torch
@require_torchaudio @require_torchaudio

View File

@@ -473,50 +473,24 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs) self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)
def _check_attentions_for_generate( def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
): ):
self.assertIsInstance(attentions, tuple) # adds PAD dummy token, expected shape is off by 1
self.assertListEqual( prompt_length += 1
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) output_length += 1
super()._check_attentions_for_generate(
batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
) )
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
for idx, iter_attentions in enumerate(attentions):
# adds PAD dummy token
tgt_len = min_length + idx + 1
src_len = min_length + idx + 1
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
def _check_hidden_states_for_generate( def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
): ):
self.assertIsInstance(hidden_states, tuple) # adds PAD dummy token, expected shape is off by 1
self.assertListEqual( prompt_length += 1
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], output_length += 1
[True] * len(hidden_states), super()._check_hidden_states_for_generate(
batch_size, hidden_states, prompt_length, output_length, config, use_cache
) )
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
for idx, iter_hidden_states in enumerate(hidden_states):
# adds PAD dummy token
seq_len = min_length + idx + 1
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
# check hidden size
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
[expected_shape] * len(iter_hidden_states),
)
pass
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):

View File

@@ -636,57 +636,52 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
weight.data.fill_(3) weight.data.fill_(3)
def _check_hidden_states_for_generate( def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
): ):
self.assertIsInstance(hidden_states, tuple) self.assertIsInstance(hidden_states, tuple)
self.assertListEqual( self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states), [True] * len(hidden_states),
) )
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) self.assertEqual(len(hidden_states), (output_length - prompt_length))
for idx, iter_hidden_states in enumerate(hidden_states): for generated_length, iter_hidden_states in enumerate(hidden_states):
# check hidden size # check hidden size
for i, layer_hidden_states in enumerate(iter_hidden_states): for i, layer_hidden_states in enumerate(iter_hidden_states):
# every 2nd tensor is from extra stream # every 2nd tensor is from extra stream
if i % 2 != 0: if i % 2 != 0:
seq_len = 1 model_output_length = 1
else: else:
# for first item dummy PAD token is appended so need one more # for first item dummy PAD token is appended so need one more
# else offset+dummy_token when using cache # else offset+dummy_token when using cache
seq_len = (min_length + 1) if idx == 0 else 3 model_output_length = (prompt_length + 1) if generated_length == 0 else 3
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) expected_shape = (batch_size, model_output_length, config.hidden_size)
self.assertEqual(layer_hidden_states.shape, expected_shape) self.assertEqual(layer_hidden_states.shape, expected_shape)
def _check_attentions_for_generate( def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
): ):
self.assertIsInstance(attentions, tuple) self.assertIsInstance(attentions, tuple)
self.assertListEqual( self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
) )
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) self.assertEqual(len(attentions), (output_length - prompt_length))
for idx, attentions_item in enumerate(attentions): for generated_length, attentions_item in enumerate(attentions):
for iter_attentions in attentions_item: for iter_attentions in attentions_item:
tgt_len = min_length model_input_length = prompt_length
# for first item dummy PAD token is appended so need one more # for first item dummy PAD token is appended so need one more
# every token after consists of offset+dummy_token length when using cache # every token after consists of offset+dummy_token length when using cache
if idx == 0: if generated_length == 0:
tgt_len += 1 model_input_length += 1
else: else:
tgt_len = 3 model_input_length = 3
src_len = min_length + idx + 1 query_length = prompt_length + generated_length + 1
expected_shape = ( expected_shape = (batch_size, config.num_attention_heads, model_input_length, query_length)
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
# check attn size # check attn size
self.assertListEqual( self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [layer_attention.shape for layer_attention in iter_attentions],