Offloaded cache: fix generate (#34921)
* fix cache impl * require_torch_gpu * fix mamba * fix copies
This commit is contained in:
committed by
GitHub
parent
57ca9e6d2f
commit
5e8c1d713d
@@ -1880,6 +1880,32 @@ class GenerationTesterMixin:
|
||||
)
|
||||
)
|
||||
|
||||
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
|
||||
@require_torch_gpu
|
||||
@pytest.mark.generate
|
||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||
"""Tests we can generate by indicating `cache_implementation` for each possible cache class"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_cache_class:
|
||||
self.skipTest(reason="This model does not support the new cache format")
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
generation_kwargs = {
|
||||
"max_new_tokens": 5,
|
||||
"use_cache": True,
|
||||
"cache_implementation": cache_implementation,
|
||||
}
|
||||
|
||||
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
|
||||
|
||||
# Most cache classes have their own tests except for some that are tested here
|
||||
# The ones here do not need special treatment when passing `cache_implementation`
|
||||
# and are not bound to specific models only
|
||||
new_results = model.generate(**generation_kwargs, **inputs_dict)
|
||||
self.assertListEqual(legacy_results.tolist(), new_results.tolist())
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_with_static_cache(self):
|
||||
"""
|
||||
|
||||
@@ -16,7 +16,9 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
@@ -365,6 +367,12 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
def test_model_parallelism(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("offloaded",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Offloaded cache seems to not work with mllama's kv cache type")
|
||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||
pass
|
||||
|
||||
def test_generate_text_only_with_cache(self):
|
||||
"""
|
||||
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
|
||||
|
||||
@@ -567,6 +567,12 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_generate_with_head_masking(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("offloaded",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Whisper doesnt work with offloaded cache implementation yet")
|
||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||
pass
|
||||
|
||||
@require_torch_fp16
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
Reference in New Issue
Block a user