Cache docs: update (#32929)

* some changes

* more updates

* fix cache copy

* nits

* nits

* add tests
This commit is contained in:
Raushan Turganbay
2024-09-04 12:05:31 +02:00
committed by GitHub
parent 35f72ebf47
commit ebbe8d8014
3 changed files with 114 additions and 39 deletions

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
from packaging import version
@@ -616,3 +617,34 @@ class CacheIntegrationTest(unittest.TestCase):
model.generate(generation_config=offloaded, **inputs)
offloaded_peak_memory = torch.cuda.max_memory_allocated(device)
assert offloaded_peak_memory < original_peak_memory
@require_torch_gpu
def test_cache_copy(self):
model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
prompt_cache = StaticCache(
config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16
)
INITIAL_PROMPT = "You are a helpful assistant. "
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
# This is the common prompt cached, we need to run forward without grad to be abel to copy
with torch.no_grad():
prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values
prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
responses = []
for prompt in prompts:
new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
past_key_values = copy.deepcopy(prompt_cache)
outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=40)
response = tokenizer.batch_decode(outputs)[0]
responses.append(response)
EXPECTED_DECODED_TEXT = [
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
] # fmt: skip
self.assertTrue(responses == EXPECTED_DECODED_TEXT)