def test_cache_when_needed_at_train_time(self):
    """
    Check that a cache is used at train time when we explicitly request it.

    Some fine-tuning methods, like prefix tuning in PEFT, require the use of a cache while training.
    Related issue: #35648
    """
    lm = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL)
    tok = AutoTokenizer.from_pretrained(TINY_MISTRAL)
    encoded = tok("Hello, my dog is cute", return_tensors="pt")

    # The loaded model does not start in training mode, so switch it on explicitly
    self.assertFalse(lm.training)
    lm.train()

    # Asking for `use_cache=True` while training must hand back a cache object
    outputs = lm(**encoded, use_cache=True)
    self.assertIsInstance(outputs.past_key_values, DynamicCache)
    self.assertTrue(lm.training)

    # Simulate injecting virtual tokens, as prefix tuning does: one random tensor is reused for both
    # legacy-cache entries and then converted into a `DynamicCache`
    num_virtual_tokens = 3
    virtual_states = torch.randn(2, 1, 2, num_virtual_tokens, 8)
    prefix_cache = DynamicCache.from_legacy_cache([virtual_states, virtual_states])

    # Extend the attention mask by one column of ones per virtual token
    original_mask = encoded["attention_mask"]
    virtual_mask = torch.ones(1, num_virtual_tokens).to(original_mask.device)
    encoded["attention_mask"] = torch.cat((original_mask, virtual_mask), dim=1)

    outputs = lm(**encoded, past_key_values=prefix_cache, use_cache=True)
    self.assertTrue(lm.training)

    # The cache can also be turned off when the training loop has no use for it; no cache is returned then
    outputs = lm(**encoded, use_cache=False)
    self.assertIsNone(outputs.past_key_values)
    self.assertTrue(lm.training)