[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache
* add test
* tmp
* manually turn on eager attn when we want to output attn
* typo
* generalize to encoder-decoder models
* force compilation on cpu
* tmp commit
* fix static cache shape checks
* models with odd caches
* fix copies
* shorter cache search loop
* use `decoder_past_key_values` everywhere
* better test variable names and comments
* signature
* rename `_check_outputs` into `_check_generate_outputs`
* add comments
* `HybridCache` future test note
This commit is contained in:
@@ -72,7 +72,14 @@ if is_torch_available():
|
||||
SpeechEncoderDecoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
|
||||
from transformers.cache_utils import (
|
||||
Cache,
|
||||
DynamicCache,
|
||||
EncoderDecoderCache,
|
||||
HybridCache,
|
||||
QuantoQuantizedCache,
|
||||
StaticCache,
|
||||
)
|
||||
from transformers.generation import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
@@ -473,6 +480,8 @@ class GenerationTesterMixin:
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._greedy_generate(
|
||||
@@ -499,12 +508,14 @@ class GenerationTesterMixin:
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(output_generate, model.config)
|
||||
self._check_generate_outputs(output_generate, model.config)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
@@ -531,7 +542,7 @@ class GenerationTesterMixin:
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
)
|
||||
|
||||
self._check_outputs(output_generate, model.config, use_cache=True)
|
||||
self._check_generate_outputs(output_generate, model.config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_sample_generate(self):
|
||||
@@ -550,6 +561,8 @@ class GenerationTesterMixin:
|
||||
def test_sample_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._sample_generate(
|
||||
@@ -577,7 +590,7 @@ class GenerationTesterMixin:
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(output_generate, model.config, num_return_sequences=2)
|
||||
self._check_generate_outputs(output_generate, model.config, num_return_sequences=2)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_beam_search_generate(self):
|
||||
@@ -598,6 +611,8 @@ class GenerationTesterMixin:
|
||||
def test_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
@@ -625,7 +640,7 @@ class GenerationTesterMixin:
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(
|
||||
self._check_generate_outputs(
|
||||
output_generate,
|
||||
model.config,
|
||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||
@@ -642,6 +657,8 @@ class GenerationTesterMixin:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
@@ -666,7 +683,7 @@ class GenerationTesterMixin:
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
)
|
||||
|
||||
self._check_outputs(
|
||||
self._check_generate_outputs(
|
||||
output_generate,
|
||||
model.config,
|
||||
use_cache=True,
|
||||
@@ -721,6 +738,8 @@ class GenerationTesterMixin:
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
@@ -750,7 +769,7 @@ class GenerationTesterMixin:
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(
|
||||
self._check_generate_outputs(
|
||||
output_generate,
|
||||
model.config,
|
||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||
@@ -813,6 +832,8 @@ class GenerationTesterMixin:
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||
@@ -840,7 +861,7 @@ class GenerationTesterMixin:
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(
|
||||
self._check_generate_outputs(
|
||||
output_generate,
|
||||
model.config,
|
||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||
@@ -909,6 +930,8 @@ class GenerationTesterMixin:
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
@@ -947,7 +970,7 @@ class GenerationTesterMixin:
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(
|
||||
self._check_generate_outputs(
|
||||
output_generate,
|
||||
model.config,
|
||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||
@@ -999,6 +1022,8 @@ class GenerationTesterMixin:
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
config.is_decoder = True
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._contrastive_generate(
|
||||
@@ -1019,7 +1044,7 @@ class GenerationTesterMixin:
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
)
|
||||
|
||||
self._check_outputs(output_generate, model.config, use_cache=True)
|
||||
self._check_generate_outputs(output_generate, model.config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
@@ -1205,7 +1230,7 @@ class GenerationTesterMixin:
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self._check_similar_generate_outputs(output_greedy, output_assisted)
|
||||
for output in (output_greedy, output_assisted):
|
||||
self._check_outputs(output, model.config, use_cache=True)
|
||||
self._check_generate_outputs(output, model.config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||
@@ -1270,7 +1295,7 @@ class GenerationTesterMixin:
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self._check_similar_generate_outputs(output_greedy, output_prompt_lookup)
|
||||
for output in (output_greedy, output_prompt_lookup):
|
||||
self._check_outputs(output, model.config, use_cache=True)
|
||||
self._check_generate_outputs(output, model.config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_dola_decoding_sample(self):
|
||||
@@ -1320,7 +1345,7 @@ class GenerationTesterMixin:
|
||||
"dola_layers": "low",
|
||||
}
|
||||
output_dola = model.generate(**generation_kwargs, **logits_processor_kwargs, **inputs_dict)
|
||||
self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
|
||||
self._check_generate_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_assisted_decoding_sample(self):
|
||||
@@ -1381,7 +1406,7 @@ class GenerationTesterMixin:
|
||||
}
|
||||
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
self._check_outputs(output_assisted, config, use_cache=True)
|
||||
self._check_generate_outputs(output_assisted, config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_prompt_lookup_decoding_stops_at_eos(self):
|
||||
@@ -1419,6 +1444,8 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
text_config = config.get_text_config()
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
||||
# We want to test only encoder-decoder models
|
||||
if not text_config.is_encoder_decoder:
|
||||
@@ -1765,7 +1792,7 @@ class GenerationTesterMixin:
|
||||
num_hidden_layers = text_config.num_hidden_layers
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
max_cache_len += inputs_embeds.shape[1]
|
||||
max_cache_len += inputs_embeds.shape[1] - 1 # the last generated token has no cache
|
||||
outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict)
|
||||
|
||||
# we should get `max_length` in shape, not `max_length - embeds_length`
|
||||
@@ -2003,7 +2030,7 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
# Check 1: The cache shapes must match the expected shapes
|
||||
max_cache_len = seq_length + max_new_tokens
|
||||
max_cache_len = seq_length + max_new_tokens - 1 # cache len = gen len - 1, the last token has no cache
|
||||
text_config = config.text_config if hasattr(config, "text_config") else config
|
||||
head_dim = (
|
||||
text_config.head_dim
|
||||
@@ -2138,6 +2165,58 @@ class GenerationTesterMixin:
|
||||
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
|
||||
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_compilation_all_outputs(self):
|
||||
"""
|
||||
Tests that all optional outputs are behaving as expected when compilation is triggered.
|
||||
In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered.
|
||||
"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_static_cache:
|
||||
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
# compilation-specific setup
|
||||
torch.compiler.reset() # prevent cached compilation from being used in the test
|
||||
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
|
||||
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
|
||||
if not has_defined_cache_implementation:
|
||||
model.generation_config.cache_implementation = "static"
|
||||
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
|
||||
output_generate = model.generate(
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
min_new_tokens=self.max_new_tokens,
|
||||
output_attentions=True,
|
||||
output_hidden_states=True,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=True,
|
||||
**logits_processor_kwargs,
|
||||
**inputs_dict,
|
||||
)
|
||||
|
||||
# Sanity check: compilation has happened
|
||||
self.assertTrue(hasattr(model, "_compiled_call"))
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(
|
||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||
)
|
||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||
|
||||
self._check_generate_outputs(output_generate, model.config, use_cache=True)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_methods_with_logits_to_keep(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -2290,86 +2369,90 @@ class GenerationTesterMixin:
|
||||
# check whether we still need the overwrites
|
||||
self._test_attention_implementation("flash_attention_2")
|
||||
|
||||
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
|
||||
internal_batch_size = (
|
||||
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
|
||||
)
|
||||
|
||||
seq_length = getattr(self.model_tester, "seq_length", None)
|
||||
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
|
||||
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
|
||||
prompt_length = getattr(self.model_tester, "seq_length", None)
|
||||
prompt_length = getattr(self.model_tester, "encoder_seq_length", prompt_length)
|
||||
prompt_length = getattr(self.model_tester, "text_seq_length", prompt_length)
|
||||
|
||||
config = config.text_config if hasattr(config, "text_config") else config
|
||||
|
||||
gen_len = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
||||
generated_length = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length
|
||||
)
|
||||
decoder_past_key_values = getattr(output, "past_key_values", None)
|
||||
if config.is_encoder_decoder and isinstance(decoder_past_key_values, EncoderDecoderCache):
|
||||
decoder_past_key_values = decoder_past_key_values.self_attention_cache
|
||||
|
||||
# in some models we subsample the sequence length in inner layers
|
||||
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
|
||||
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
||||
prompt_length = self.model_tester.get_subsampled_output_lengths(prompt_length)
|
||||
|
||||
# scores
|
||||
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
|
||||
self._check_scores(
|
||||
batch_size=internal_batch_size, scores=output.scores, generated_length=generated_length, config=config
|
||||
)
|
||||
|
||||
# unprocessed logits
|
||||
self._check_logits(internal_batch_size, output.logits, config=config)
|
||||
self._check_logits(batch_size=internal_batch_size, logits=output.logits, config=config)
|
||||
|
||||
# Attentions
|
||||
if self.has_attentions:
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(
|
||||
output.encoder_attentions, input_batch_size, config, seq_length
|
||||
attentions=output.encoder_attentions,
|
||||
batch_size=input_batch_size,
|
||||
config=config,
|
||||
prompt_length=prompt_length,
|
||||
)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
internal_batch_size,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
batch_size=internal_batch_size,
|
||||
attentions=output.decoder_attentions,
|
||||
prompt_length=1, # the BOS token
|
||||
output_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
)
|
||||
else:
|
||||
# if use_cache first input is equal to no use_cache, so skip here
|
||||
attentions = output.attentions if not use_cache else output.attentions[1:]
|
||||
min_length = seq_length if not use_cache else seq_length + 1
|
||||
self._check_attentions_for_generate(
|
||||
internal_batch_size,
|
||||
attentions=attentions,
|
||||
min_length=min_length,
|
||||
max_length=output.sequences.shape[-1],
|
||||
batch_size=internal_batch_size,
|
||||
attentions=output.attentions,
|
||||
prompt_length=prompt_length,
|
||||
output_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
)
|
||||
|
||||
# Hidden States
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
self._check_encoder_hidden_states_for_generate(
|
||||
output.encoder_hidden_states, input_batch_size, config, seq_length
|
||||
hidden_states=output.encoder_hidden_states,
|
||||
batch_size=input_batch_size,
|
||||
config=config,
|
||||
prompt_length=prompt_length,
|
||||
)
|
||||
|
||||
# decoder
|
||||
self._check_hidden_states_for_generate(
|
||||
internal_batch_size,
|
||||
output.decoder_hidden_states,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
batch_size=internal_batch_size,
|
||||
hidden_states=output.decoder_hidden_states,
|
||||
prompt_length=1, # the BOS token
|
||||
output_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
else:
|
||||
# if use_cache first input is equal to no use_cache, so skip here
|
||||
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
|
||||
min_length = seq_length if not use_cache else seq_length + 1
|
||||
self._check_hidden_states_for_generate(
|
||||
internal_batch_size,
|
||||
hidden_states,
|
||||
min_length=min_length,
|
||||
max_length=output.sequences.shape[-1],
|
||||
batch_size=internal_batch_size,
|
||||
hidden_states=output.hidden_states,
|
||||
prompt_length=prompt_length,
|
||||
output_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
@@ -2396,59 +2479,73 @@ class GenerationTesterMixin:
|
||||
)
|
||||
if has_standard_cache:
|
||||
if use_cache:
|
||||
past_key_values = output.past_key_values
|
||||
past_sequence_length = output.sequences.shape[-1] - 1
|
||||
cache_length = output.sequences.shape[-1] - 1
|
||||
self._check_past_key_values_for_generate(
|
||||
internal_batch_size,
|
||||
past_key_values,
|
||||
seq_length=past_sequence_length,
|
||||
batch_size=internal_batch_size,
|
||||
decoder_past_key_values=decoder_past_key_values,
|
||||
cache_length=cache_length,
|
||||
config=config,
|
||||
)
|
||||
elif use_cache is False:
|
||||
self.assertTrue(output.past_key_values is None)
|
||||
self.assertTrue(decoder_past_key_values is None)
|
||||
|
||||
def _check_scores(self, batch_size, scores, length, config):
|
||||
def _check_scores(self, batch_size, scores, generated_length, config):
|
||||
vocab_size = config.get_text_config(decoder=True).vocab_size
|
||||
expected_shape = (batch_size, vocab_size)
|
||||
self.assertIsInstance(scores, tuple)
|
||||
self.assertEqual(len(scores), length)
|
||||
self.assertEqual(len(scores), generated_length)
|
||||
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
|
||||
|
||||
def _check_logits(self, batch_size, scores, config):
|
||||
def _check_logits(self, batch_size, logits, config):
|
||||
vocab_size = config.get_text_config(decoder=True).vocab_size
|
||||
self.assertIsInstance(scores, tuple)
|
||||
self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
|
||||
self.assertIsInstance(logits, tuple)
|
||||
self.assertListEqual([iter_logits.shape[0] for iter_logits in logits], [batch_size] * len(logits))
|
||||
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
|
||||
vocab_diff = vocab_size - scores[0].shape[-1]
|
||||
vocab_diff = vocab_size - logits[0].shape[-1]
|
||||
self.assertTrue(vocab_diff in [0, 1])
|
||||
self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
|
||||
self.assertListEqual([vocab_size - score.shape[-1] for score in logits], [vocab_diff] * len(logits))
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||
):
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
self.assertEqual(len(attentions), (output_length - prompt_length))
|
||||
|
||||
for idx, iter_attentions in enumerate(attentions):
|
||||
tgt_len = min_length + idx if not use_cache else 1
|
||||
src_len = min_length + idx
|
||||
use_cache = decoder_past_key_values is not None
|
||||
has_static_cache = isinstance(decoder_past_key_values, (StaticCache, HybridCache))
|
||||
|
||||
# When `output_attentions=True`, each iteration of generate appends the attentions corresponding to the new
|
||||
# token(s)
|
||||
# NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more
|
||||
# elaborate checks
|
||||
for generated_length, iter_attentions in enumerate(attentions):
|
||||
# regardless of using cache, the first forward pass will have the full prompt as input
|
||||
if use_cache and generated_length > 0:
|
||||
model_input_length = 1
|
||||
else:
|
||||
model_input_length = prompt_length + generated_length
|
||||
query_length = (
|
||||
prompt_length + generated_length
|
||||
if not has_static_cache
|
||||
else decoder_past_key_values.get_max_cache_shape()
|
||||
)
|
||||
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
batch_size,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
model_input_length,
|
||||
query_length,
|
||||
)
|
||||
# check attn size
|
||||
self.assertListEqual(
|
||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
||||
)
|
||||
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
||||
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
|
||||
encoder_expected_shape = (batch_size, config.num_attention_heads, prompt_length, prompt_length)
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[layer_attentions.shape for layer_attentions in attentions],
|
||||
@@ -2456,71 +2553,75 @@ class GenerationTesterMixin:
|
||||
)
|
||||
|
||||
def _check_hidden_states_for_generate(
|
||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||
):
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
||||
[True] * len(hidden_states),
|
||||
)
|
||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
||||
self.assertEqual(len(hidden_states), (output_length - prompt_length))
|
||||
|
||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
||||
seq_len = min_length + idx if not use_cache else 1
|
||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
||||
# When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
|
||||
# new token(s)
|
||||
# NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more
|
||||
# elaborate checks
|
||||
for generated_length, iter_hidden_states in enumerate(hidden_states):
|
||||
# regardless of using cache, the first forward pass will have the full prompt as input
|
||||
if use_cache and generated_length > 0:
|
||||
model_input_length = 1
|
||||
else:
|
||||
model_input_length = prompt_length + generated_length
|
||||
expected_shape = (batch_size, model_input_length, config.hidden_size)
|
||||
# check hidden size
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
||||
[expected_shape] * len(iter_hidden_states),
|
||||
)
|
||||
|
||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
|
||||
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
|
||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
|
||||
encoder_expected_shape = (batch_size, prompt_length, config.hidden_size)
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
|
||||
[encoder_expected_shape] * len(hidden_states),
|
||||
)
|
||||
|
||||
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
|
||||
self.assertIsInstance(past_key_values, (tuple, Cache))
|
||||
|
||||
# Encoder-decoder models: pull and verify the decoder cache
|
||||
if isinstance(past_key_values, EncoderDecoderCache):
|
||||
past_key_values = past_key_values.self_attention_cache
|
||||
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
|
||||
self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
|
||||
|
||||
# (batch, head, seq_length, head_features)
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
batch_size,
|
||||
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
|
||||
seq_length,
|
||||
cache_length,
|
||||
config.hidden_size // config.num_attention_heads,
|
||||
)
|
||||
|
||||
if isinstance(past_key_values, Cache):
|
||||
if isinstance(decoder_past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in past_key_values.key_cache],
|
||||
[expected_shape] * len(past_key_values.key_cache),
|
||||
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.key_cache),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_tensor.shape for value_tensor in past_key_values.value_cache],
|
||||
[expected_shape] * len(past_key_values.value_cache),
|
||||
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.value_cache),
|
||||
)
|
||||
|
||||
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
|
||||
else:
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
|
||||
[True] * len(past_key_values),
|
||||
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values],
|
||||
[True] * len(decoder_past_key_values),
|
||||
)
|
||||
# check shape key, value
|
||||
self.assertListEqual(
|
||||
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
|
||||
[expected_shape] * len(past_key_values),
|
||||
[layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values],
|
||||
[expected_shape] * len(decoder_past_key_values),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
|
||||
[expected_shape] * len(past_key_values),
|
||||
[layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values],
|
||||
[expected_shape] * len(decoder_past_key_values),
|
||||
)
|
||||
|
||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||
|
||||
Reference in New Issue
Block a user