[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache * add test * tmp * manually turn on eager attn when we want to output attn * typo * generalize to encoder-decoder models * force compilation on cpu * tmp commit * fix static cache shape checks * models with odd caches * fix copies * shorter cache search loop * use decoder_past_key_values everywhere * better test variable names and comments * signature * rename _check_outputs into _check_generate_outputs * add comments * HybridCache future test note
This commit is contained in:
@@ -367,9 +367,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
def _check_attentions_for_generate(self, *args, **kwargs):
|
||||
return True # Model does not return attention
|
||||
|
||||
@unittest.skip(reason="Past key values are not returned")
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||
pass
|
||||
@@ -382,9 +379,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
def test_model_parallel_beam_search(self):
|
||||
pass
|
||||
|
||||
def _check_past_key_values_for_generate(self, *args, **kwargs):
|
||||
return True
|
||||
|
||||
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
|
||||
def test_assisted_decoding_matches_greedy_search(self):
|
||||
pass
|
||||
@@ -397,25 +391,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
def _check_hidden_states_for_generate(
|
||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
||||
[True] * len(hidden_states),
|
||||
)
|
||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
||||
|
||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
||||
seq_len = min_length + idx if not use_cache else 1
|
||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
||||
# check hidden size
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
||||
[expected_shape] * len(iter_hidden_states),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="TODO @arthurzucker not super important and failing.")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user