[cache] add a test to confirm we can use cache at train time (#35709)
* add test * augment test as suggested * Update tests/utils/test_modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * rerun tests --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -37,6 +37,7 @@ from transformers import (
|
|||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForImageClassification,
|
AutoModelForImageClassification,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
|
DynamicCache,
|
||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
OwlViTForObjectDetection,
|
OwlViTForObjectDetection,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
@@ -1790,6 +1791,43 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
)
|
)
|
||||||
self.assertTrue(check_models_equal(model, model_loaded))
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
|
def test_cache_when_needed_at_train_time(self):
|
||||||
|
"""
|
||||||
|
Some fine-tuning methods require the use of cache, like prefix tuning in PEFT. This test checks that a cache
|
||||||
|
is at train time used if we request it. Related issue: #35648
|
||||||
|
"""
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(TINY_MISTRAL)
|
||||||
|
model_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
|
|
||||||
|
# By default it is not training, we have to set it
|
||||||
|
self.assertFalse(model.training)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
# If we set `use_cache=True` while training, then a cache is returned
|
||||||
|
model_outputs = model(**model_inputs, use_cache=True)
|
||||||
|
self.assertIsInstance(model_outputs.past_key_values, DynamicCache)
|
||||||
|
self.assertTrue(model.training)
|
||||||
|
|
||||||
|
# simulate injecting virtual tokens like in prefix tuning
|
||||||
|
num_virtual_tokens = 3
|
||||||
|
past_key_values = [torch.randn(2, 1, 2, num_virtual_tokens, 8)] * 2
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
model_inputs["attention_mask"] = torch.cat(
|
||||||
|
(
|
||||||
|
model_inputs["attention_mask"],
|
||||||
|
torch.ones(1, num_virtual_tokens).to(model_inputs["attention_mask"].device),
|
||||||
|
),
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
model_outputs = model(**model_inputs, past_key_values=past_key_values, use_cache=True)
|
||||||
|
self.assertTrue(model.training)
|
||||||
|
|
||||||
|
# We can also disable the cache to skip a few operations, if the training loop doesn't need cache
|
||||||
|
model_outputs = model(**model_inputs, use_cache=False)
|
||||||
|
self.assertIsNone(model_outputs.past_key_values)
|
||||||
|
self.assertTrue(model.training)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user