[generate] return Cache object even if passed in a legacy format (#35673)

* generate returns a Cache object by default

* fix tests

* fix test for encoder-decoder models
This commit is contained in:
Joao Gante
2025-01-16 17:06:24 +00:00
committed by GitHub
parent 2818307e93
commit 94af1c0aa2
9 changed files with 36 additions and 156 deletions

View File

@@ -18,7 +18,6 @@ import gc
import unittest
import pytest
from parameterized import parameterized
from transformers import AutoTokenizer, JetMoeConfig, is_torch_available
from transformers.testing_utils import (
@@ -299,10 +298,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
test_disk_offload_bin = False
test_disk_offload_safetensors = False
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
def setUp(self):
self.model_tester = JetMoeModelTester(self)
self.config_tester = ConfigTester(