Add CB (#38085)
* stash for now
* initial commit
* small updated
* up
* up
* works!
* nits and fixes
* don't loop too much
* finish working example
* update
* fix the small freeblocks issue
* feat: stream inputs to continuous batch
* fix: update attn from `eager` to `sdpa`
* refactor: fmt
* refactor: cleanup unnecessary code
* feat: add `update` fn to `PagedAttentionCache`
* feat: broken optimal block size computation
* fix: debugging invalid cache logic
* fix: attention mask
* refactor: use custom prompts for example
* feat: add streaming output
* fix: prefill split refactor: add doc strings and unsound/redundant logic fix: compute optimal blocks logic
* fix: send decoded tokens when `prefilling_split` -> `decoding`
* refactor: move logic to appropriate parent class
* fix: remove truncation as we split prefilling anyways refactor: early return when we have enough selected requests
* feat: add paged attention forward
* push Ggraoh>
* add paged sdpa
* update
* btter mps defaults
* feat: add progress bar for `generate_batch`
* feat: add opentelemetry metrics (ttft + batch fill %age)
* feat: add tracing
* Add cuda graphs (#38059)
* draft cudagraphs addition
* nits
* styling
* update
* fix
* kinda draft of what it should look like
* fixes
* lol
* not sure why inf everywhere
* can generate but output is shit
* some fixes
* we should have a single device synch
* broken outputs but it does run
* refactor
* updates
* updates with some fixes
* fix mask causality
* another commit that casts after
* add error
* simplify example
* update
* updates
* revert llama changes
* fix merge conflicts
* fix: tracing and metrics
* my updates
* update script default values
* fix block allocation issue
* fix prefill split attnetion mask
* no bugs
* add paged eager
* fix
* update
* style
* feat: add pytorch traces
* fix
* fix
* refactor: remove pytorch profiler data
* style
* nits
* cleanup
* draft test file
* fix
* fix
* fix paged and graphs
* small renamings
* cleanups and push
* refactor: move tracing and metrics logic to utils
* refactor: trace more blocks of code
* nits
* nits
* update
* to profile or not to profile
* refactor: create new output object
* causal by default
* cleanup but generations are still off for IDK what reason
* simplifications but not running still
* this does work.
* small quality of life updates
* nits
* updaet
* fix the scheduler
* fix warning
* ol
* fully fixed
* nits
* different generation parameters
* nice
* just style
* feat: add cache memory usage
* feat: add kv cache free memory
* feat: add active/waiting count & req latency
* do the sampling
* fix: synchronize CUDA only if available and improve error handling in ContinuousBatchingManager
* fix on mps
* feat: add dashboard & histogram buckets
* perf: improve waiting reqs data structures
* attempt to compile, but we should only do it on mps AFAIK
* feat: decouple scheduling logic
* just a draft
* c;eanup and fixup
* optional
* style
* update
* update
* remove the draft documentation
* fix import as well
* update
* fix the test
* style doomed

---------

Co-authored-by: Luc Georges <luc.sydney.georges@gmail.com>
86
tests/generation/test_paged_attention.py
Normal file
@@ -0,0 +1,86 @@
import time
import unittest

from parameterized import parameterized

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.testing_utils import require_flash_attn, require_torch_gpu, run_slow


_TEST_PROMPTS = [
    "A man is a walking his dog down the street, and a the turn he sees",
    "Describe a fruit that is of orange color and round. It is a sweet fruit and a great source of Vitamine C. The fruit I'm thinking of is an",
    "A plane is flying high in the sky, out of the window are clouds and mountains. Where could the plane be located?",
    "Please fill in the form to",
    "For safety reasons, the train is stopped in the middle of the",
]

_EXPECTED_OUTPUTS = [
    "a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes up a conversation, and they quickly discover that they have a lot in common. They exchange numbers and",
    "orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n## Step 2: Determine the taste and nutritional value of the fruit\nThe fruit is described as sweet",
    "This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer is not a straightforward one, and it requires some lateral thinking to arrive at the correct solution.",
    "get in touch with us. We will respond to your message as soon as possible.\n\n[Your Name]\n[Your Email]\n[Your Phone Number]\n[Your Message]\n\nWe are looking forward to hearing from you!\n\n[Insert Contact Information]\n\nNote:",
    "track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers does the train travel in 30 minutes?\n## Step 1: Convert the speed from km/h to km/min",
]


@run_slow
@require_torch_gpu
@require_flash_attn
class TestBatchGeneration(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3b-Instruct", torch_dtype="bfloat16", device_map="auto"
        ).eval()

        cls.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3b-Instruct", padding_side="left")

        if cls.tokenizer.pad_token is None:
            cls.tokenizer.pad_token = cls.tokenizer.eos_token
            cls.model.config.pad_token_id = cls.model.config.eos_token_id

        cls.model.use_cache = False

    @parameterized.expand(
        [
            ("eager_paged", 64, 128, 64),
            ("sdpa_paged", 32, 256, 128),
            ("paged_attention", 16, 512, 256),
            ("flex_paged", 64, 128, 64),
        ]
    )
    def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):
        self.model.config.attn_implementation = attn_impl

        generation_config = GenerationConfig(
            max_new_tokens=50,
            top_k=0,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=False,
            num_blocks=num_blocks,
            block_size=block_size,
            max_batch_tokens=max_batch_tokens,
        )

        tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
        batch_inputs = list(tokenized["input_ids"])

        start = time.time()
        batch_outputs = self.model.generate_batch(
            inputs=batch_inputs,
            generation_config=generation_config,
        )
        end = time.time()
        print(
            f"\n[{attn_impl}] Batch took {end - start:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
        )

        for i, req_id in enumerate(batch_outputs):
            generated = self.tokenizer.decode(batch_outputs[req_id].static_outputs, skip_special_tokens=False).strip()
            expected = _EXPECTED_OUTPUTS[i].strip()
            self.assertTrue(
                generated.startswith(expected),
                msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected start: {expected}\nGot: {generated}",
            )
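
The four parameterized cases in the test vary `num_blocks`, `block_size`, and `max_batch_tokens`. As a rough sanity check on those values, the sketch below computes how many token slots each configuration would provide, assuming a paged KV cache holds at most `num_blocks * block_size` tokens. That capacity formula is an assumption made for illustration and is not stated in this diff. The helpers `cache_capacity_tokens` and `blocks_needed` are hypothetical names and are not part of transformers.

```python
# Illustrative only: these helpers are hypothetical and not part of transformers.
# Tuples mirror the test's parameterized cases:
# (attn_impl, num_blocks, block_size, max_batch_tokens)
CASES = [
    ("eager_paged", 64, 128, 64),
    ("sdpa_paged", 32, 256, 128),
    ("paged_attention", 16, 512, 256),
    ("flex_paged", 64, 128, 64),
]


def cache_capacity_tokens(num_blocks: int, block_size: int) -> int:
    """Assumed capacity: each block holds `block_size` token slots."""
    return num_blocks * block_size


def blocks_needed(num_tokens: int, block_size: int) -> int:
    """Ceiling division: blocks required to store `num_tokens` tokens."""
    return -(-num_tokens // block_size)


capacities = {name: cache_capacity_tokens(nb, bs) for name, nb, bs, _ in CASES}

for name, nb, bs, mbt in CASES:
    print(f"{name}: capacity={capacities[name]} token slots, max_batch_tokens={mbt}")
```

Under that assumption, all four cases provide the same 8192 token slots, so they would differ only in block granularity and in the per-step token budget (`max_batch_tokens`), not in total cache size.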