Update ux cb (#39845)

* cleanup

* nits

* updates

* fix logging

* push updates?

* just pass exception

* update

* nits

* fix

* add tokencount

* style
This commit is contained in:
Arthur
2025-08-01 16:50:28 +02:00
committed by GitHub
parent 3951d4ad5d
commit 6ea646a03a
4 changed files with 121 additions and 179 deletions

View File

@@ -10,26 +10,27 @@ from transformers.generation import GenerationConfig
torch.set_float32_matmul_precision("high") torch.set_float32_matmul_precision("high")
model_id = "meta-llama/Llama-3.2-3b-Instruct" model_id = "meta-llama/Llama-3.2-3b-Instruct"
model = AutoModelForCausalLM.from_pretrained( model = (
model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto" AutoModelForCausalLM.from_pretrained(
).eval() model_id,
attn_implementation="paged_attention|kernels-community/flash-attn",
torch_dtype=torch.bfloat16,
)
.eval()
.cuda()
)
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
generation_config = GenerationConfig( generation_config = GenerationConfig(
max_new_tokens=512, max_new_tokens=512,
# use_cuda_graph=False,
eos_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,
use_cache=False, do_sample=False,
num_blocks=2048,
block_size=128,
do_sample=True,
max_batch_tokens=1024, # Maximum number of tokens to process in a single batch
scheduler="prefill_first",
) )
train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
train_dataset = train_dataset.select(range(500)) # Use only 5 examples for the simple version
# --- Example 1: Simple Version using generate_batch ---
print("--- Running CB Generation Example ---") print("--- Running CB Generation Example ---")
@@ -41,19 +42,21 @@ tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
start_time_simple = time.time() start_time_simple = time.time()
# model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs", fullgraph=True) model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
batch_outputs = model.generate_batch( batch_outputs = model.generate_batch(
inputs=simple_batch_inputs, inputs=simple_batch_inputs,
generation_config=generation_config, generation_config=generation_config,
) )
end_time_simple = time.time() end_time_simple = time.time()
token_count = 0
for request in batch_outputs: for request in batch_outputs:
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False) input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
try: try:
output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False) output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
token_count += len(batch_outputs[request].generated_tokens[1:])
except Exception as e: except Exception as e:
print(f"Decoding failed for request {request}: {e}") print(f"Decoding failed for request {request}: {e}")
token_count += len(batch_outputs[request].generated_tokens[1:])
output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False) output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False)
if len(output_text) > 0: if len(output_text) > 0:
print("-" * 20) print("-" * 20)
@@ -65,7 +68,9 @@ print("-" * 20)
print("--- Finished CB Generation Example ---\n\n") print("--- Finished CB Generation Example ---\n\n")
print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds") print(
f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds for {token_count} tokens. {token_count / (end_time_simple - start_time_simple)}tok/s"
)
# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version # train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version

View File

@@ -13,9 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import queue import queue
import statistics
import threading import threading
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -29,11 +27,11 @@ import torch
import torch.nn as nn import torch.nn as nn
from tokenizers import Tokenizer from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream from tokenizers.decoders import DecodeStream
from torch.profiler import profile, schedule, tensorboard_trace_handler
from tqdm import tqdm from tqdm import tqdm
from ..configuration_utils import PretrainedConfig from ..configuration_utils import PretrainedConfig
from ..generation.configuration_utils import GenerationConfig from ..generation.configuration_utils import GenerationConfig
from ..utils.logging import logging
from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
@@ -49,9 +47,7 @@ class RequestStatus(Enum):
FAILED = "failed" FAILED = "failed"
# Setup your logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
@dataclass @dataclass
@@ -159,8 +155,8 @@ class PagedAttentionCache:
generation_config: GenerationConfig, generation_config: GenerationConfig,
device: torch.device, device: torch.device,
dtype: torch.dtype = torch.float16, dtype: torch.dtype = torch.float16,
num_requests: int = 100,
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
initial_prompt_shapes: Optional[list[list[int]]] = None,
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
) -> None: ) -> None:
"""Initialize a paged attention cache for efficient memory usage. """Initialize a paged attention cache for efficient memory usage.
@@ -179,23 +175,6 @@ class PagedAttentionCache:
if getattr(config, "num_key_value_heads", None) is None if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads else config.num_key_value_heads
) )
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.num_hidden_layers = config.num_hidden_layers
# Calculate optimal block size and number if not provided
num_blocks = getattr(generation_config, "num_blocks", None)
block_size = getattr(generation_config, "block_size", None)
if num_blocks is None or block_size is None:
logger.info("Calculating optimal block size and number...")
num_blocks, block_size = compute_optimal_blocks(
device, config, generation_config, initial_prompt_shapes or [], dtype, median_prefill_length=200
)
logger.info(f"Using calculated num_blocks={num_blocks}, block_size={block_size}")
self.block_size = block_size
self.num_blocks = num_blocks
num_key_value_heads = self.num_key_value_heads num_key_value_heads = self.num_key_value_heads
if tp_size is not None and tp_size > 1: if tp_size is not None and tp_size > 1:
if num_key_value_heads % tp_size != 0: if num_key_value_heads % tp_size != 0:
@@ -203,8 +182,33 @@ class PagedAttentionCache:
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}." f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
) )
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly. # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
num_key_value_heads //= tp_size self.num_key_value_heads //= tp_size
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.num_hidden_layers = config.num_hidden_layers
# Calculate optimal block size and number if not provided
num_blocks = getattr(generation_config, "num_blocks", None)
block_size = getattr(generation_config, "block_size", 32)
max_memory_percent = getattr(generation_config, "max_memory", 0.9)
num_blocks, max_batch_tokens = compute_optimal_blocks(
generation_config.max_new_tokens,
block_size=block_size,
head_dim=self.head_dim,
num_layers=self.num_hidden_layers,
num_heads=self.num_key_value_heads,
max_memory_percent=max_memory_percent,
dtype=dtype,
num_blocks=num_blocks,
)
logger.warning(
f"Using calculated num_blocks={num_blocks}, block_size={block_size}, max concurrent requests {max_batch_tokens}"
)
self.max_batch_tokens = max_batch_tokens
self.block_size = block_size
self.num_blocks = num_blocks
self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim) self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim)
self.dtype = dtype self.dtype = dtype
@@ -249,7 +253,7 @@ class PagedAttentionCache:
blocks_to_free = self._block_tables.pop(request_id) blocks_to_free = self._block_tables.pop(request_id)
self._free_blocks.extend(blocks_to_free) self._free_blocks.extend(blocks_to_free)
else: else:
logger.warning(f"Attempted to free blocks for non-existent request_id: {request_id}") logger.info(f"Attempted to free blocks for non-existent request_id: {request_id}")
def get_num_free_blocks(self) -> int: def get_num_free_blocks(self) -> int:
"""Returns the number of free blocks available.""" """Returns the number of free blocks available."""
@@ -343,7 +347,7 @@ class Scheduler(ABC):
@traced @traced
def has_pending_requests(self) -> bool: def has_pending_requests(self) -> bool:
"""Check if there are requests ready to be processed.""" """Check if there are requests ready to be processed."""
return self.active_requests or self.waiting_requests return len(self.active_requests) or len(self.waiting_requests)
@abstractmethod @abstractmethod
def finish_request(self, request_id: str, evict_from_cache: bool = True): def finish_request(self, request_id: str, evict_from_cache: bool = True):
@@ -595,94 +599,60 @@ class PrefillFirstScheduler(Scheduler):
del self.active_requests[request_id] del self.active_requests[request_id]
def get_device_and_memory():
    """Pick the best available torch device and report its memory usage.

    Preference order is CUDA, then MPS, then CPU.

    Returns:
        A 4-tuple ``(device, total_memory, reserved_memory, allocated_memory)``:

        * ``device``: the selected ``torch.device``.
        * ``total_memory``: memory budget of the device in bytes, or ``None``
          on CPU, where no budget is queried. Callers that do arithmetic on
          this value must handle ``None``.
        * ``reserved_memory``: value of ``torch.cuda.memory_reserved`` on
          CUDA; ``0`` on MPS and CPU.
        * ``allocated_memory``: bytes currently allocated on the device;
          ``0`` on CPU.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        total_memory = torch.cuda.get_device_properties(device).total_memory
        reserved_memory = torch.cuda.memory_reserved(device)
        allocated_memory = torch.cuda.memory_allocated(device)
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = torch.device("mps")
        # The budget on MPS is the recommended max working-set size; what the
        # Metal driver has already handed to this process is the allocated
        # part. (Previously these were swapped, which made the "allocated"
        # figure negative and the "total" figure equal to current usage.)
        total_memory = torch.mps.recommended_max_memory()
        allocated_memory = torch.mps.driver_allocated_memory()
        reserved_memory = 0  # MPS does not track reserved separately
    else:
        device = torch.device("cpu")
        total_memory = None
        reserved_memory = 0
        allocated_memory = 0
    return device, total_memory, reserved_memory, allocated_memory
@traced(standalone=True) @traced(standalone=True)
def compute_optimal_blocks( def compute_optimal_blocks(
device: torch.device, max_num_tokens,
config: PretrainedConfig, block_size,
generation_config: GenerationConfig, head_dim,
inputs: list[list[int]], num_heads,
dtype: torch.dtype = torch.bfloat16, num_layers,
safety_margin: float = 0.9, max_memory_percent=0.9,
median_prefill_length: Optional[int] = None, num_blocks=None,
dtype=torch.float16,
): ):
"""Calculate optimal number and size of blocks for the KV cache. device, total, reserved, allocated = get_device_and_memory()
available_memory = int((total - max(allocated, reserved)) * max_memory_percent)
Args:
device: The device where the model runs
config: The model configuration
generation_config: The generation configuration
inputs: Sample input sequences to estimate memory requirements
dtype: Data type for cache tensors
safety_margin: Fraction of available memory to use
median_prefill_length: Override for median prefill length calculation
Returns:
Tuple of (num_blocks, block_size)
"""
# Extract model dimensions
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
num_hidden_layers = getattr(config, "num_hidden_layers", 40)
# Get available device memory
if device.type == "cuda":
device_properties = torch.cuda.get_device_properties(device)
total_memory = device_properties.total_memory
allocated_memory = torch.cuda.memory_allocated(device)
reserved_memory = torch.cuda.memory_reserved(device)
available_memory = total_memory - max(allocated_memory, reserved_memory)
elif device.type == "mps":
logger.warning("MPS memory estimation is approximate. Using conservative defaults.")
return 2048, 256
else:
logger.warning(f"Unsupported device type {device.type} for optimal block calculation. Using defaults.")
return 32, 128
# Apply safety margin
available_memory = int(available_memory * safety_margin)
if available_memory <= 0:
logger.warning("Not enough available memory. Using minimum configuration.")
return 8, 128 # Minimum viable configuration
# Calculate memory per token
dtype_size = torch.tensor([], dtype=dtype).element_size() dtype_size = torch.tensor([], dtype=dtype).element_size()
memory_per_token = 2 * num_kv_heads * head_dim * dtype_size * num_hidden_layers # For K and V caches bytes_per_token = 2 * num_heads * head_dim * dtype_size * num_layers
if num_blocks is not None:
# Estimate sequence length requirements # TODO
tokens_to_generate = getattr(generation_config, "max_new_tokens") or 20 max_possible_concurrent_requests = num_blocks * bytes_per_token
# FIXME: forgot to add the initial prompt length in the mix....
if median_prefill_length is None and inputs: max_possible_concurrent_requests = int(
non_empty_inputs = [len(seq) for seq in inputs if seq] available_memory // (bytes_per_token * max_num_tokens * max_num_tokens // 4)
median_prefill_length = int(statistics.median(non_empty_inputs)) if non_empty_inputs else 64
elif median_prefill_length is None:
median_prefill_length = 64 # Reasonable default if no inputs provided
# Total sequence length including generated tokens
seq_length = median_prefill_length + tokens_to_generate
# Calculate block parameters
MIN_BLOCK_SIZE = 16
# Estimate number of concurrent sequences
per_sequence_memory = seq_length * memory_per_token
max_concurrent_sequences = max(1, int(available_memory // per_sequence_memory))
# Total tokens that can fit in memory
total_tokens = available_memory // memory_per_token
# Calculate block size (rounded to power of 2)
initial_block_size = max(MIN_BLOCK_SIZE, total_tokens // (max_concurrent_sequences * 2))
block_size = 1 << (initial_block_size - 1).bit_length() # Round to power of 2
# Calculate number of blocks
num_blocks = max(1, total_tokens // block_size)
logger.info(
f"Optimal cache: {num_blocks} blocks of size {block_size} "
f"(can handle ~{num_blocks * block_size // seq_length} sequences of length {seq_length})"
) )
if max_possible_concurrent_requests <= 0:
return int(num_blocks), int(block_size) logger.warning("you are trying to generate a bit too many tokens")
max_possible_concurrent_requests = 32
max_concurrent_tokens = min(64, max_possible_concurrent_requests)
# FIXME: Optimal means uses all memory
optimal_num_blocks = max(((max_concurrent_tokens * max_num_tokens) // block_size) + 1, 64)
return optimal_num_blocks, max_concurrent_tokens
@dataclass @dataclass
@@ -775,11 +745,9 @@ class ContinuousBatchProcessor:
self.requests_in_batch: list[RequestState] = [] self.requests_in_batch: list[RequestState] = []
# Get batch size parameters from generation config
self._configure_batch_parameters()
# Set up metrics collector # Set up metrics collector
self.metrics = ContinuousBatchProcessorMetrics(self.max_batch_tokens) self.max_batch_tokens = cache.max_batch_tokens
self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens)
self.setup_static_tensors() self.setup_static_tensors()
@@ -847,25 +815,6 @@ class ContinuousBatchProcessor:
+ self.get_model_kwargs().__repr__() + self.get_model_kwargs().__repr__()
) )
@traced(standalone=True)
def _configure_batch_parameters(self):
"""Set up batch processing parameters based on generation config."""
# Calculate total cache capacity
total_cache_tokens = self.cache.num_blocks * self.cache.block_size
# Get or calculate max tokens per batch
user_batch_tokens = getattr(self.generation_config, "max_batch_tokens", None)
if user_batch_tokens is not None:
self.max_batch_tokens = user_batch_tokens
else:
# Default to 1/8 of total cache capacity, adjusted for context
self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048)
recommended_batch_size = min(total_cache_tokens // 8, self.max_context_len)
self.max_batch_tokens = max(64, recommended_batch_size)
# Context length and EOS token
self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048)
@traced @traced
def _get_new_requests(self): def _get_new_requests(self):
"""Pull new requests from the input queue and add to waiting list.""" """Pull new requests from the input queue and add to waiting list."""
@@ -1041,6 +990,8 @@ class ContinuousBatchProcessor:
self._maybe_send_output(state, token) self._maybe_send_output(state, token)
elif state.status == RequestStatus.PREFILLING_SPLIT: elif state.status == RequestStatus.PREFILLING_SPLIT:
state.status = RequestStatus.SPLIT_PENDING_REMAINDER state.status = RequestStatus.SPLIT_PENDING_REMAINDER
if self.cache.get_num_free_blocks() == 0:
raise ValueError("No more free blocks")
@traced @traced
def has_pending_requests(self) -> bool: def has_pending_requests(self) -> bool:
@@ -1062,7 +1013,9 @@ class ContinuousBatchProcessor:
Args: Args:
error: The error to report in the failure message error: The error to report in the failure message
""" """
for state in self.scheduler.active_requests.values():
requests = list(self.scheduler.active_requests.values())
for state in requests:
self._handle_request_error(error, state) self._handle_request_error(error, state)
self.scheduler.finish_request(state.request_id) self.scheduler.finish_request(state.request_id)
@@ -1296,6 +1249,7 @@ class ContinuousBatchingManager:
self.generation_config, self.generation_config,
self.model.device, self.model.device,
self.model.dtype, self.model.dtype,
num_requests=len(self.input_queue.queue),
tp_size=getattr(self.model, "tp_size"), tp_size=getattr(self.model, "tp_size"),
) )
@@ -1324,33 +1278,10 @@ class ContinuousBatchingManager:
) )
self.batch_processor = batch_processor self.batch_processor = batch_processor
is_first = True is_first = True
while (not self.stop_event.is_set()) or batch_processor.has_pending_requests():
if self.profile: self._inner_generation_loop(batch_processor, is_first)
tracing_schedule = schedule(skip_first=2, warmup=1, active=1, repeat=3, wait=1) if is_first:
trace_handler = tensorboard_trace_handler( is_first = False
dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile"
)
activities = [
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
with profile(
activities=activities,
schedule=tracing_schedule,
on_trace_ready=trace_handler,
record_shapes=False,
with_stack=True,
) as prof:
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
self._inner_generation_loop(batch_processor, is_first)
if is_first:
is_first = False
prof.step()
else:
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
self._inner_generation_loop(batch_processor, is_first)
if is_first:
is_first = False
except Exception as e: except Exception as e:
logger.error(f"Error in generation loop: {e}", exc_info=True) logger.error(f"Error in generation loop: {e}", exc_info=True)
@@ -1363,6 +1294,8 @@ class ContinuousBatchingManager:
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.synchronize() torch.cuda.synchronize()
batch_processor.prepare_next_batch() batch_processor.prepare_next_batch()
device, total, reserved, allocated = get_device_and_memory()
logger.info(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}")
if torch.cuda.is_available() and self.use_cuda_graph: if torch.cuda.is_available() and self.use_cuda_graph:
if is_first: if is_first:
self.warmup(batch_processor) self.warmup(batch_processor)
@@ -1502,6 +1435,7 @@ class ContinuousMixin:
results[req_id] = result results[req_id] = result
finished_count += 1 finished_count += 1
pbar.update(1) pbar.update(1)
logger.debug(manager.batch_processor.tokenizer.decode(result.generated_tokens))
else: else:
if not manager.is_running(): if not manager.is_running():
logger.error("Generation thread terminated unexpectedly.") logger.error("Generation thread terminated unexpectedly.")

View File

@@ -4,8 +4,11 @@ from ..generation.continuous_batching import PagedAttentionCache
from ..utils import is_flash_attn_2_available from ..utils import is_flash_attn_2_available
if is_flash_attn_2_available(): try:
from flash_attn import flash_attn_varlen_func # noqa: F401 if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func # noqa: F401
except Exception:
pass
def paged_attention_forward( def paged_attention_forward(

View File

@@ -2705,9 +2705,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
kernel_function = partial(attention_wrapper, implementation=kernel) kernel_function = partial(attention_wrapper, implementation=kernel)
elif kernel_name is not None: elif kernel_name is not None:
kernel_function = getattr(kernel, kernel_name) kernel_function = getattr(kernel, kernel_name)
ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function) ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
ALL_MASK_ATTENTION_FUNCTIONS.register( ALL_MASK_ATTENTION_FUNCTIONS.register(
applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
) )
except Exception as e: except Exception as e:
logger.warning_once( logger.warning_once(
@@ -2715,8 +2715,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
"default attention implementation instead (sdpa if available, eager otherwise)." "default attention implementation instead (sdpa if available, eager otherwise)."
) )
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case attn_implementation = "sdpa" # Try to fallback to sdpa in this case
return applicable_attn_implementation return attn_implementation
else: else:
return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check) return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)