[cache] add a test to confirm we can use cache at train time (#35709)

* add test

* augment test as suggested

* Update tests/utils/test_modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* rerun tests

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Joao Gante
2025-01-16 17:02:34 +00:00
committed by GitHub
parent 57bf1a12a0
commit aeeceb9916

View File

@@ -37,6 +37,7 @@ from transformers import (
AutoModel, AutoModel,
AutoModelForImageClassification, AutoModelForImageClassification,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
DynamicCache,
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
OwlViTForObjectDetection, OwlViTForObjectDetection,
PretrainedConfig, PretrainedConfig,
@@ -1790,6 +1791,43 @@ class ModelUtilsTest(TestCasePlus):
) )
self.assertTrue(check_models_equal(model, model_loaded)) self.assertTrue(check_models_equal(model, model_loaded))
def test_cache_when_needed_at_train_time(self):
"""
Some fine-tuning methods require the use of cache, like prefix tuning in PEFT. This test checks that a cache
is at train time used if we request it. Related issue: #35648
"""
model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL)
tokenizer = AutoTokenizer.from_pretrained(TINY_MISTRAL)
model_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# By default it is not training, we have to set it
self.assertFalse(model.training)
model.train()
# If we set `use_cache=True` while training, then a cache is returned
model_outputs = model(**model_inputs, use_cache=True)
self.assertIsInstance(model_outputs.past_key_values, DynamicCache)
self.assertTrue(model.training)
# simulate injecting virtual tokens like in prefix tuning
num_virtual_tokens = 3
past_key_values = [torch.randn(2, 1, 2, num_virtual_tokens, 8)] * 2
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
model_inputs["attention_mask"] = torch.cat(
(
model_inputs["attention_mask"],
torch.ones(1, num_virtual_tokens).to(model_inputs["attention_mask"].device),
),
dim=1,
)
model_outputs = model(**model_inputs, past_key_values=past_key_values, use_cache=True)
self.assertTrue(model.training)
# We can also disable the cache to skip a few operations, if the training loop doesn't need cache
model_outputs = model(**model_inputs, use_cache=False)
self.assertIsNone(model_outputs.past_key_values)
self.assertTrue(model.training)
@slow @slow
@require_torch @require_torch