Update ux cb (#39845)

* cleanup

* nits

* updates

* fix logging

* push updates?

* just pass exception

* update

* nits

* fix

* add tokencount

* style
This commit is contained in:
Arthur
2025-08-01 16:50:28 +02:00
committed by GitHub
parent 3951d4ad5d
commit 6ea646a03a
4 changed files with 121 additions and 179 deletions

View File

@@ -10,26 +10,27 @@ from transformers.generation import GenerationConfig
torch.set_float32_matmul_precision("high")
model_id = "meta-llama/Llama-3.2-3b-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
).eval()
model = (
AutoModelForCausalLM.from_pretrained(
model_id,
attn_implementation="paged_attention|kernels-community/flash-attn",
torch_dtype=torch.bfloat16,
)
.eval()
.cuda()
)
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
generation_config = GenerationConfig(
max_new_tokens=512,
# use_cuda_graph=False,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
num_blocks=2048,
block_size=128,
do_sample=True,
max_batch_tokens=1024, # Maximum number of tokens to process in a single batch
scheduler="prefill_first",
do_sample=False,
)
train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
# --- Example 1: Simple Version using generate_batch ---
train_dataset = train_dataset.select(range(500)) # Use only the first 500 examples for the simple version
print("--- Running CB Generation Example ---")
@@ -41,19 +42,21 @@ tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
start_time_simple = time.time()
# model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs", fullgraph=True)
model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
batch_outputs = model.generate_batch(
inputs=simple_batch_inputs,
generation_config=generation_config,
)
end_time_simple = time.time()
token_count = 0
for request in batch_outputs:
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
try:
output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
token_count += len(batch_outputs[request].generated_tokens[1:])
except Exception as e:
print(f"Decoding failed for request {request}: {e}")
token_count += len(batch_outputs[request].generated_tokens[1:])
output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False)
if len(output_text) > 0:
print("-" * 20)
@@ -65,7 +68,9 @@ print("-" * 20)
print("--- Finished CB Generation Example ---\n\n")
print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")
print(
f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds for {token_count} tokens. {token_count / (end_time_simple - start_time_simple)}tok/s"
)
# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version

View File

@@ -13,9 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import queue
import statistics
import threading
import time
from abc import ABC, abstractmethod
@@ -29,11 +27,11 @@ import torch
import torch.nn as nn
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from torch.profiler import profile, schedule, tensorboard_trace_handler
from tqdm import tqdm
from ..configuration_utils import PretrainedConfig
from ..generation.configuration_utils import GenerationConfig
from ..utils.logging import logging
from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
@@ -49,9 +47,7 @@ class RequestStatus(Enum):
FAILED = "failed"
# Setup your logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
@dataclass
@@ -159,8 +155,8 @@ class PagedAttentionCache:
generation_config: GenerationConfig,
device: torch.device,
dtype: torch.dtype = torch.float16,
num_requests: int = 100,
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
initial_prompt_shapes: Optional[list[list[int]]] = None,
tp_size: Optional[int] = None,
) -> None:
"""Initialize a paged attention cache for efficient memory usage.
@@ -179,23 +175,6 @@ class PagedAttentionCache:
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.num_hidden_layers = config.num_hidden_layers
# Calculate optimal block size and number if not provided
num_blocks = getattr(generation_config, "num_blocks", None)
block_size = getattr(generation_config, "block_size", None)
if num_blocks is None or block_size is None:
logger.info("Calculating optimal block size and number...")
num_blocks, block_size = compute_optimal_blocks(
device, config, generation_config, initial_prompt_shapes or [], dtype, median_prefill_length=200
)
logger.info(f"Using calculated num_blocks={num_blocks}, block_size={block_size}")
self.block_size = block_size
self.num_blocks = num_blocks
num_key_value_heads = self.num_key_value_heads
if tp_size is not None and tp_size > 1:
if num_key_value_heads % tp_size != 0:
@@ -203,8 +182,33 @@ class PagedAttentionCache:
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
)
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
num_key_value_heads //= tp_size
self.num_key_value_heads //= tp_size
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.num_hidden_layers = config.num_hidden_layers
# Calculate optimal block size and number if not provided
num_blocks = getattr(generation_config, "num_blocks", None)
block_size = getattr(generation_config, "block_size", 32)
max_memory_percent = getattr(generation_config, "max_memory", 0.9)
num_blocks, max_batch_tokens = compute_optimal_blocks(
generation_config.max_new_tokens,
block_size=block_size,
head_dim=self.head_dim,
num_layers=self.num_hidden_layers,
num_heads=self.num_key_value_heads,
max_memory_percent=max_memory_percent,
dtype=dtype,
num_blocks=num_blocks,
)
logger.warning(
f"Using calculated num_blocks={num_blocks}, block_size={block_size}, max concurrent requests {max_batch_tokens}"
)
self.max_batch_tokens = max_batch_tokens
self.block_size = block_size
self.num_blocks = num_blocks
self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim)
self.dtype = dtype
@@ -249,7 +253,7 @@ class PagedAttentionCache:
blocks_to_free = self._block_tables.pop(request_id)
self._free_blocks.extend(blocks_to_free)
else:
logger.warning(f"Attempted to free blocks for non-existent request_id: {request_id}")
logger.info(f"Attempted to free blocks for non-existent request_id: {request_id}")
def get_num_free_blocks(self) -> int:
"""Returns the number of free blocks available."""
@@ -343,7 +347,7 @@ class Scheduler(ABC):
@traced
def has_pending_requests(self) -> bool:
"""Check if there are requests ready to be processed."""
return self.active_requests or self.waiting_requests
return len(self.active_requests) or len(self.waiting_requests)
@abstractmethod
def finish_request(self, request_id: str, evict_from_cache: bool = True):
@@ -595,94 +599,60 @@ class PrefillFirstScheduler(Scheduler):
del self.active_requests[request_id]
def get_device_and_memory():
    """Pick the best available device and report its memory usage in bytes.

    Devices are probed in order of preference: CUDA, then MPS, then CPU.

    Returns:
        A 4-tuple ``(device, total_memory, reserved_memory, allocated_memory)``:

        - ``device``: the selected ``torch.device``.
        - ``total_memory``: memory budget of the device, or ``None`` on CPU,
          where no budget is queried. Callers doing arithmetic on this value
          must handle ``None`` themselves.
        - ``reserved_memory``: memory held by the allocator, including cached
          pools (``0`` on CPU).
        - ``allocated_memory``: memory currently occupied by tensors (``0`` on
          CPU).
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        total_memory = torch.cuda.get_device_properties(device).total_memory
        reserved_memory = torch.cuda.memory_reserved(device)
        allocated_memory = torch.cuda.memory_allocated(device)
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = torch.device("mps")
        # On MPS the usable budget is Metal's recommended max working-set size.
        # `driver_allocated_memory` is what the process already holds (cached
        # allocator pools included), so it plays the role of "reserved", not of
        # "total"; `current_allocated_memory` is what live tensors occupy.
        total_memory = torch.mps.recommended_max_memory()
        reserved_memory = torch.mps.driver_allocated_memory()
        allocated_memory = torch.mps.current_allocated_memory()
    else:
        device = torch.device("cpu")
        # No memory budget is queried for CPU; report "unknown total, nothing used".
        total_memory = None
        reserved_memory = 0
        allocated_memory = 0
    return device, total_memory, reserved_memory, allocated_memory
@traced(standalone=True)
def compute_optimal_blocks(
device: torch.device,
config: PretrainedConfig,
generation_config: GenerationConfig,
inputs: list[list[int]],
dtype: torch.dtype = torch.bfloat16,
safety_margin: float = 0.9,
median_prefill_length: Optional[int] = None,
max_num_tokens,
block_size,
head_dim,
num_heads,
num_layers,
max_memory_percent=0.9,
num_blocks=None,
dtype=torch.float16,
):
"""Calculate optimal number and size of blocks for the KV cache.
device, total, reserved, allocated = get_device_and_memory()
available_memory = int((total - max(allocated, reserved)) * max_memory_percent)
Args:
device: The device where the model runs
config: The model configuration
generation_config: The generation configuration
inputs: Sample input sequences to estimate memory requirements
dtype: Data type for cache tensors
safety_margin: Fraction of available memory to use
median_prefill_length: Override for median prefill length calculation
Returns:
Tuple of (num_blocks, block_size)
"""
# Extract model dimensions
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
num_hidden_layers = getattr(config, "num_hidden_layers", 40)
# Get available device memory
if device.type == "cuda":
device_properties = torch.cuda.get_device_properties(device)
total_memory = device_properties.total_memory
allocated_memory = torch.cuda.memory_allocated(device)
reserved_memory = torch.cuda.memory_reserved(device)
available_memory = total_memory - max(allocated_memory, reserved_memory)
elif device.type == "mps":
logger.warning("MPS memory estimation is approximate. Using conservative defaults.")
return 2048, 256
else:
logger.warning(f"Unsupported device type {device.type} for optimal block calculation. Using defaults.")
return 32, 128
# Apply safety margin
available_memory = int(available_memory * safety_margin)
if available_memory <= 0:
logger.warning("Not enough available memory. Using minimum configuration.")
return 8, 128 # Minimum viable configuration
# Calculate memory per token
dtype_size = torch.tensor([], dtype=dtype).element_size()
memory_per_token = 2 * num_kv_heads * head_dim * dtype_size * num_hidden_layers # For K and V caches
# Estimate sequence length requirements
tokens_to_generate = getattr(generation_config, "max_new_tokens") or 20
if median_prefill_length is None and inputs:
non_empty_inputs = [len(seq) for seq in inputs if seq]
median_prefill_length = int(statistics.median(non_empty_inputs)) if non_empty_inputs else 64
elif median_prefill_length is None:
median_prefill_length = 64 # Reasonable default if no inputs provided
# Total sequence length including generated tokens
seq_length = median_prefill_length + tokens_to_generate
# Calculate block parameters
MIN_BLOCK_SIZE = 16
# Estimate number of concurrent sequences
per_sequence_memory = seq_length * memory_per_token
max_concurrent_sequences = max(1, int(available_memory // per_sequence_memory))
# Total tokens that can fit in memory
total_tokens = available_memory // memory_per_token
# Calculate block size (rounded to power of 2)
initial_block_size = max(MIN_BLOCK_SIZE, total_tokens // (max_concurrent_sequences * 2))
block_size = 1 << (initial_block_size - 1).bit_length() # Round to power of 2
# Calculate number of blocks
num_blocks = max(1, total_tokens // block_size)
logger.info(
f"Optimal cache: {num_blocks} blocks of size {block_size} "
f"(can handle ~{num_blocks * block_size // seq_length} sequences of length {seq_length})"
bytes_per_token = 2 * num_heads * head_dim * dtype_size * num_layers
if num_blocks is not None:
# TODO
max_possible_concurrent_requests = num_blocks * bytes_per_token
# FIXME: the initial prompt length is not yet accounted for in this estimate.
max_possible_concurrent_requests = int(
available_memory // (bytes_per_token * max_num_tokens * max_num_tokens // 4)
)
return int(num_blocks), int(block_size)
if max_possible_concurrent_requests <= 0:
logger.warning("you are trying to generate a bit too many tokens")
max_possible_concurrent_requests = 32
max_concurrent_tokens = min(64, max_possible_concurrent_requests)
# FIXME: Optimal means uses all memory
optimal_num_blocks = max(((max_concurrent_tokens * max_num_tokens) // block_size) + 1, 64)
return optimal_num_blocks, max_concurrent_tokens
@dataclass
@@ -775,11 +745,9 @@ class ContinuousBatchProcessor:
self.requests_in_batch: list[RequestState] = []
# Get batch size parameters from generation config
self._configure_batch_parameters()
# Set up metrics collector
self.metrics = ContinuousBatchProcessorMetrics(self.max_batch_tokens)
self.max_batch_tokens = cache.max_batch_tokens
self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens)
self.setup_static_tensors()
@@ -847,25 +815,6 @@ class ContinuousBatchProcessor:
+ self.get_model_kwargs().__repr__()
)
@traced(standalone=True)
def _configure_batch_parameters(self):
"""Set up batch processing parameters based on generation config."""
# Calculate total cache capacity
total_cache_tokens = self.cache.num_blocks * self.cache.block_size
# Get or calculate max tokens per batch
user_batch_tokens = getattr(self.generation_config, "max_batch_tokens", None)
if user_batch_tokens is not None:
self.max_batch_tokens = user_batch_tokens
else:
# Default to 1/8 of total cache capacity, adjusted for context
self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048)
recommended_batch_size = min(total_cache_tokens // 8, self.max_context_len)
self.max_batch_tokens = max(64, recommended_batch_size)
# Context length and EOS token
self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048)
@traced
def _get_new_requests(self):
"""Pull new requests from the input queue and add to waiting list."""
@@ -1041,6 +990,8 @@ class ContinuousBatchProcessor:
self._maybe_send_output(state, token)
elif state.status == RequestStatus.PREFILLING_SPLIT:
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
if self.cache.get_num_free_blocks() == 0:
raise ValueError("No more free blocks")
@traced
def has_pending_requests(self) -> bool:
@@ -1062,7 +1013,9 @@ class ContinuousBatchProcessor:
Args:
error: The error to report in the failure message
"""
for state in self.scheduler.active_requests.values():
requests = list(self.scheduler.active_requests.values())
for state in requests:
self._handle_request_error(error, state)
self.scheduler.finish_request(state.request_id)
@@ -1296,6 +1249,7 @@ class ContinuousBatchingManager:
self.generation_config,
self.model.device,
self.model.dtype,
num_requests=len(self.input_queue.queue),
tp_size=getattr(self.model, "tp_size"),
)
@@ -1324,30 +1278,7 @@ class ContinuousBatchingManager:
)
self.batch_processor = batch_processor
is_first = True
if self.profile:
tracing_schedule = schedule(skip_first=2, warmup=1, active=1, repeat=3, wait=1)
trace_handler = tensorboard_trace_handler(
dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile"
)
activities = [
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
with profile(
activities=activities,
schedule=tracing_schedule,
on_trace_ready=trace_handler,
record_shapes=False,
with_stack=True,
) as prof:
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
self._inner_generation_loop(batch_processor, is_first)
if is_first:
is_first = False
prof.step()
else:
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
while (not self.stop_event.is_set()) or batch_processor.has_pending_requests():
self._inner_generation_loop(batch_processor, is_first)
if is_first:
is_first = False
@@ -1363,6 +1294,8 @@ class ContinuousBatchingManager:
if torch.cuda.is_available():
torch.cuda.synchronize()
batch_processor.prepare_next_batch()
device, total, reserved, allocated = get_device_and_memory()
logger.info(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}")
if torch.cuda.is_available() and self.use_cuda_graph:
if is_first:
self.warmup(batch_processor)
@@ -1502,6 +1435,7 @@ class ContinuousMixin:
results[req_id] = result
finished_count += 1
pbar.update(1)
logger.debug(manager.batch_processor.tokenizer.decode(result.generated_tokens))
else:
if not manager.is_running():
logger.error("Generation thread terminated unexpectedly.")

View File

@@ -4,8 +4,11 @@ from ..generation.continuous_batching import PagedAttentionCache
from ..utils import is_flash_attn_2_available
if is_flash_attn_2_available():
try:
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func # noqa: F401
except Exception:
pass
def paged_attention_forward(

View File

@@ -2705,9 +2705,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
kernel_function = partial(attention_wrapper, implementation=kernel)
elif kernel_name is not None:
kernel_function = getattr(kernel, kernel_name)
ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function)
ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
ALL_MASK_ATTENTION_FUNCTIONS.register(
applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
)
except Exception as e:
logger.warning_once(
@@ -2715,8 +2715,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
"default attention implementation instead (sdpa if available, eager otherwise)."
)
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
return applicable_attn_implementation
attn_implementation = "sdpa" # Try to fallback to sdpa in this case
return attn_implementation
else:
return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)