Offloaded cache: fix generate (#34921)

* fix cache impl

* require_torch_gpu

* fix mamba

* fix copies
This commit is contained in:
Raushan Turganbay
2024-11-28 15:05:56 +01:00
committed by GitHub
parent 57ca9e6d2f
commit 5e8c1d713d
6 changed files with 91 additions and 19 deletions

View File

@@ -16,7 +16,9 @@
import unittest
import pytest
import requests
from parameterized import parameterized
from transformers import (
AutoProcessor,
@@ -365,6 +367,12 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
def test_model_parallelism(self):
pass
@parameterized.expand([("offloaded",)])
@pytest.mark.generate
@unittest.skip(reason="Offloaded cache seems to not work with mllama's kv cache type")
def test_offloaded_cache_implementation(self, cache_implementation):
pass
def test_generate_text_only_with_cache(self):
"""
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature

View File

@@ -567,6 +567,12 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_with_head_masking(self):
pass
@parameterized.expand([("offloaded",)])
@pytest.mark.generate
@unittest.skip(reason="Whisper doesnt work with offloaded cache implementation yet")
def test_offloaded_cache_implementation(self, cache_implementation):
pass
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()