Cache: models return input cache type (#30716)

This commit is contained in:
Joao Gante
2024-05-08 18:26:34 +01:00
committed by GitHub
parent 71c1985069
commit f26e407370
11 changed files with 30 additions and 70 deletions

View File

@@ -17,8 +17,6 @@
import unittest
from parameterized import parameterized
from transformers import DbrxConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
@@ -357,11 +355,6 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_tied_weights_keys(self):
pass
@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch
class DbrxModelIntegrationTest(unittest.TestCase):