[generate] return Cache object even if passed in a legacy format (#35673)

* generate returns a Cache object by default

* fix tests

* fix test for encoder-decoder models
This commit is contained in:
Joao Gante
2025-01-16 17:06:24 +00:00
committed by GitHub
parent 2818307e93
commit 94af1c0aa2
9 changed files with 36 additions and 156 deletions

View File

@@ -19,7 +19,6 @@ import tempfile
import unittest
import pytest
from parameterized import parameterized
from transformers import AutoTokenizer, ZambaConfig, is_torch_available
from transformers.testing_utils import (
@@ -551,11 +550,6 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
"""
self.skipTest(reason="Zamba flash attention does not support right padding")
@unittest.skip(reason="Zamba has its own special cache type")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch
class ZambaModelIntegrationTest(unittest.TestCase):