[generate] return Cache object even if passed in a legacy format (#35673)

* generate returns a Cache object by default

* fix tests

* fix test for encoder-decoder models
This commit is contained in:
Joao Gante
2025-01-16 17:06:24 +00:00
committed by GitHub
parent 2818307e93
commit 94af1c0aa2
9 changed files with 36 additions and 156 deletions

View File

@@ -18,7 +18,6 @@ import inspect
import unittest
import pytest
from parameterized import parameterized
from transformers import AutoTokenizer, BambaConfig, is_torch_available
from transformers.testing_utils import (
@@ -395,11 +394,6 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
@unittest.skip(reason="Bamba has its own special cache type")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
def test_batching_equivalence(self):
# need to disable the tril input mask
orig = self.model_tester.use_input_mask