🚨 Bloom support for cache class (#31445)

* bloom dynamic cache

* bloom follows standard cache format

* no skips for bloom anymore

* use cache position when possible

* clean up

* codestyle

* Update src/transformers/models/bloom/modeling_bloom.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/bloom/modeling_bloom.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/bloom/modeling_bloom.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* pr comments

* isinstance fix

* address comments

* make musicgen test happy

* [run-slow] bloom

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2024-07-29 10:58:59 +05:00
committed by GitHub
parent 44f6fdd74f
commit f739687684
6 changed files with 228 additions and 194 deletions

View File

@@ -1096,7 +1096,6 @@ class GenerationTesterMixin:
if any(
model_name in model_class.__name__.lower()
for model_name in [
-"bloom",
"ctrl",
"gptbigcode",
"transo_xl",
@@ -1878,7 +1877,7 @@ class GenerationTesterMixin:
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
# complete
-models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
+models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
has_standard_cache = not any(
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
)