Offloaded cache: fix generate (#34921)

* fix cache impl

* require_torch_gpu

* fix mamba

* fix copies
This commit is contained in:
Raushan Turganbay
2024-11-28 15:05:56 +01:00
committed by GitHub
parent 57ca9e6d2f
commit 5e8c1d713d
6 changed files with 91 additions and 19 deletions

View File

@@ -1880,6 +1880,32 @@ class GenerationTesterMixin:
)
)
@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
@require_torch_gpu
@pytest.mark.generate
def test_offloaded_cache_implementation(self, cache_implementation):
"""Tests we can generate by indicating `cache_implementation` for each possible cache class"""
for model_class in self.all_generative_model_classes:
if not model_class._supports_cache_class:
self.skipTest(reason="This model does not support the new cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
"use_cache": True,
"cache_implementation": cache_implementation,
}
legacy_results = model.generate(**generation_kwargs, **inputs_dict)
# Most cache classes have their own tests except for some that are tested here
# The ones here do not need special treatment when passing `cache_implementation`
# and are not bound to specific models only
new_results = model.generate(**generation_kwargs, **inputs_dict)
self.assertListEqual(legacy_results.tolist(), new_results.tolist())
@pytest.mark.generate
def test_generate_with_static_cache(self):
"""

View File

@@ -16,7 +16,9 @@
import unittest
import pytest
import requests
from parameterized import parameterized
from transformers import (
AutoProcessor,
@@ -365,6 +367,12 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
def test_model_parallelism(self):
pass
@parameterized.expand([("offloaded",)])
@pytest.mark.generate
@unittest.skip(reason="Offloaded cache seems to not work with mllama's kv cache type")
def test_offloaded_cache_implementation(self, cache_implementation):
pass
def test_generate_text_only_with_cache(self):
"""
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature

View File

@@ -567,6 +567,12 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_with_head_masking(self):
pass
@parameterized.expand([("offloaded",)])
@pytest.mark.generate
@unittest.skip(reason="Whisper doesnt work with offloaded cache implementation yet")
def test_offloaded_cache_implementation(self, cache_implementation):
pass
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()