Cache: models return input cache type (#30716)

This commit is contained in:
Joao Gante
2024-05-08 18:26:34 +01:00
committed by GitHub
parent 71c1985069
commit f26e407370
11 changed files with 30 additions and 70 deletions

View File

@@ -16,8 +16,6 @@
import unittest
from parameterized import parameterized
from transformers import CohereConfig, is_torch_available
from transformers.testing_utils import (
require_bitsandbytes,
@@ -296,11 +294,6 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip("TODO @gante fix this for Cohere")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)

View File

@@ -17,8 +17,6 @@
import unittest
from parameterized import parameterized
from transformers import DbrxConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
@@ -357,11 +355,6 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_tied_weights_keys(self):
pass
@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch
class DbrxModelIntegrationTest(unittest.TestCase):

View File

@@ -17,7 +17,6 @@ import tempfile
import unittest
import pytest
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
from transformers.testing_utils import (
@@ -367,11 +366,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@unittest.skip("Gemma buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
pass

View File

@@ -591,11 +591,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
)
@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch_gpu
class LlamaIntegrationTest(unittest.TestCase):

View File

@@ -353,11 +353,6 @@ class OlmoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@unittest.skip("TODO @gante fix this for OLMo")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch
class OlmoIntegrationTest(unittest.TestCase):

View File

@@ -15,8 +15,6 @@
""" Testing suite for the PyTorch RecurrentGemma model. """
import unittest
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
@@ -330,11 +328,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip("Recurrent gemma does not use legacy cache")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
def test_save_load_fast_init_from_base(self):
pass