Cache docs: update (#32929)
* some changes * more updates * fix cache copy * nits * nits * add tests
This commit is contained in:
committed by
GitHub
parent
35f72ebf47
commit
ebbe8d8014
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
from packaging import version
|
||||
@@ -616,3 +617,34 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
model.generate(generation_config=offloaded, **inputs)
|
||||
offloaded_peak_memory = torch.cuda.max_memory_allocated(device)
|
||||
assert offloaded_peak_memory < original_peak_memory
|
||||
|
||||
@require_torch_gpu
def test_cache_copy(self):
    """Check that a pre-filled prompt cache can be deep-copied and reused across generations.

    A `StaticCache` is filled once with a shared prompt prefix. For each full prompt
    (prefix + question), a fresh deep copy of that cache is handed to `generate`, and
    the decoded outputs are compared against the expected texts.
    """
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)

    prompt_cache = StaticCache(
        config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16
    )

    INITIAL_PROMPT = "You are a helpful assistant. "
    inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
    # This is the common prompt cached, we need to run forward without grad to be able to copy
    with torch.no_grad():
        prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values

    prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
    responses = []
    for prompt in prompts:
        new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
        # Each generation gets its own copy so the shared prefix cache is never mutated.
        past_key_values = copy.deepcopy(prompt_cache)
        outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=40)
        response = tokenizer.batch_decode(outputs)[0]
        responses.append(response)

    EXPECTED_DECODED_TEXT = [
        "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
        'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
    ]  # fmt: skip
    # assertEqual (rather than assertTrue on ==) reports which element differs on failure.
    self.assertEqual(responses, EXPECTED_DECODED_TEXT)
|
||||
|
||||
Reference in New Issue
Block a user