Offloaded cache: fix generate (#34921)
* fix cache impl * require_torch_gpu * fix mamba * fix copies
This commit is contained in:
committed by
GitHub
parent
57ca9e6d2f
commit
5e8c1d713d
@@ -16,7 +16,9 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
@@ -365,6 +367,12 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
def test_model_parallelism(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("offloaded",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Offloaded cache seems to not work with mllama's kv cache type")
|
||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||
pass
|
||||
|
||||
def test_generate_text_only_with_cache(self):
|
||||
"""
|
||||
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
|
||||
|
||||
@@ -567,6 +567,12 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_generate_with_head_masking(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("offloaded",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Whisper doesnt work with offloaded cache implementation yet")
|
||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||
pass
|
||||
|
||||
@require_torch_fp16
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
Reference in New Issue
Block a user