Cache: models return input cache type (#30716)

This commit is contained in:
Joao Gante
2024-05-08 18:26:34 +01:00
committed by GitHub
parent 71c1985069
commit f26e407370
11 changed files with 30 additions and 70 deletions

View File

@@ -15,8 +15,6 @@
""" Testing suite for the PyTorch RecurrentGemma model. """
import unittest
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
@@ -330,11 +328,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip("Recurrent gemma does not use legacy cache")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
def test_save_load_fast_init_from_base(self):
pass