[generate] can instantiate GenerationConfig(cache_implementation="static") (#35679)
fix failing instantiation
This commit is contained in:
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
|
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
|
||||||
NEEDS_CACHE_CONFIG = {}
|
CACHE_CONFIG_MAPPING = {}
|
||||||
NEED_SETUP_CACHE_CLASSES_MAPPING = {}
|
NEED_SETUP_CACHE_CLASSES_MAPPING = {}
|
||||||
QUANT_BACKEND_CLASSES_MAPPING = {}
|
QUANT_BACKEND_CLASSES_MAPPING = {}
|
||||||
ALL_CACHE_IMPLEMENTATIONS = []
|
ALL_CACHE_IMPLEMENTATIONS = []
|
||||||
@@ -62,8 +62,8 @@ if is_torch_available():
|
|||||||
)
|
)
|
||||||
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
|
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
|
||||||
|
|
||||||
NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
|
CACHE_CONFIG_MAPPING["quantized"] = QuantizedCacheConfig
|
||||||
NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig
|
CACHE_CONFIG_MAPPING["static"] = StaticCacheConfig
|
||||||
NEED_SETUP_CACHE_CLASSES_MAPPING = {
|
NEED_SETUP_CACHE_CLASSES_MAPPING = {
|
||||||
"static": StaticCache,
|
"static": StaticCache,
|
||||||
"offloaded_static": OffloadedStaticCache,
|
"offloaded_static": OffloadedStaticCache,
|
||||||
@@ -73,7 +73,7 @@ if is_torch_available():
|
|||||||
}
|
}
|
||||||
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
|
||||||
ALL_CACHE_IMPLEMENTATIONS = (
|
ALL_CACHE_IMPLEMENTATIONS = (
|
||||||
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(NEEDS_CACHE_CONFIG.keys()) + ["offloaded"]
|
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -409,11 +409,9 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.use_cache = kwargs.pop("use_cache", True)
|
self.use_cache = kwargs.pop("use_cache", True)
|
||||||
self.cache_implementation = kwargs.pop("cache_implementation", None)
|
self.cache_implementation = kwargs.pop("cache_implementation", None)
|
||||||
self.cache_config = kwargs.pop("cache_config", None)
|
self.cache_config = kwargs.pop("cache_config", None)
|
||||||
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
|
if self.cache_implementation is not None and self.cache_implementation in CACHE_CONFIG_MAPPING:
|
||||||
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
|
cache_config_class = CACHE_CONFIG_MAPPING[self.cache_implementation]
|
||||||
if self.cache_config is None:
|
if isinstance(self.cache_config, dict):
|
||||||
self.cache_config = cache_config_class()
|
|
||||||
elif isinstance(self.cache_config, dict):
|
|
||||||
self.cache_config = cache_config_class.from_dict(self.cache_config)
|
self.cache_config = cache_config_class.from_dict(self.cache_config)
|
||||||
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
|
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
|
||||||
|
|
||||||
@@ -766,7 +764,7 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
f"{ALL_CACHE_IMPLEMENTATIONS}"
|
f"{ALL_CACHE_IMPLEMENTATIONS}"
|
||||||
)
|
)
|
||||||
if self.cache_config is not None:
|
if self.cache_config is not None:
|
||||||
cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
|
cache_class = CACHE_CONFIG_MAPPING.get(self.cache_implementation)
|
||||||
if cache_class is None:
|
if cache_class is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You provided a `cache_config` but the cache implementation you are using "
|
"You provided a `cache_config` but the cache implementation you are using "
|
||||||
|
|||||||
@@ -259,6 +259,12 @@ class GenerationConfigTest(unittest.TestCase):
|
|||||||
config = GenerationConfig()
|
config = GenerationConfig()
|
||||||
self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION)
|
self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION)
|
||||||
|
|
||||||
|
def test_static_cache_without_cache_config(self):
|
||||||
|
"""Regression test for #35026 -- static cache should work without a cache config."""
|
||||||
|
config = GenerationConfig(cache_implementation="static")
|
||||||
|
self.assertEqual(config.cache_implementation, "static")
|
||||||
|
self.assertEqual(config.cache_config, None)
|
||||||
|
|
||||||
|
|
||||||
class GenerationConfigSerializationTest(unittest.TestCase):
|
class GenerationConfigSerializationTest(unittest.TestCase):
|
||||||
def test_serialize_generation_sequence_bias(self):
|
def test_serialize_generation_sequence_bias(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user