From 6ea646a03ab6f7ac330218815dfe0941a799c343 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 1 Aug 2025 16:50:28 +0200 Subject: [PATCH] Update ux cb (#39845) * clenaup * nits * updates * fix logging * push updates? * just passexception * update * nits * fix * add tokencount * style --- examples/pytorch/continuous_batching.py | 33 ++- .../generation/continuous_batching.py | 252 +++++++----------- src/transformers/integrations/flash_paged.py | 7 +- src/transformers/modeling_utils.py | 8 +- 4 files changed, 121 insertions(+), 179 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index 9aaa836f7b..75c1331907 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -10,26 +10,27 @@ from transformers.generation import GenerationConfig torch.set_float32_matmul_precision("high") model_id = "meta-llama/Llama-3.2-3b-Instruct" -model = AutoModelForCausalLM.from_pretrained( - model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto" -).eval() +model = ( + AutoModelForCausalLM.from_pretrained( + model_id, + attn_implementation="paged_attention|kernels-community/flash-attn", + torch_dtype=torch.bfloat16, + ) + .eval() + .cuda() +) tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") generation_config = GenerationConfig( max_new_tokens=512, + # use_cuda_graph=False, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, - use_cache=False, - num_blocks=2048, - block_size=128, - do_sample=True, - max_batch_tokens=1024, # Maximum number of tokens to process in a single batch - scheduler="prefill_first", + do_sample=False, ) train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") - -# --- Example 1: Simple Version using generate_batch --- +train_dataset = train_dataset.select(range(500)) # Use only 5 examples for the simple version print("--- Running CB Generation Example ---") @@ -41,19 +42,21 @@ tokenized_datasets = train_dataset.map(tokenize_function, batched=True) simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] start_time_simple = time.time() -# model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs", fullgraph=True) +model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs") batch_outputs = model.generate_batch( inputs=simple_batch_inputs, generation_config=generation_config, ) end_time_simple = time.time() - +token_count = 0 for request in batch_outputs: input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False) try: output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False) + token_count += len(batch_outputs[request].generated_tokens[1:]) except Exception as e: print(f"Decoding failed for request {request}: {e}") + token_count += len(batch_outputs[request].generated_tokens[1:]) output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False) if len(output_text) > 0: print("-" * 20) @@ -65,7 +68,9 @@ print("-" * 20) print("--- Finished CB Generation Example ---\n\n") -print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds") +print( + f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds for {token_count} tokens. {token_count / (end_time_simple - start_time_simple)}tok/s" +) # train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py index 482e28bccc..a43b11fb40 100644 --- a/src/transformers/generation/continuous_batching.py +++ b/src/transformers/generation/continuous_batching.py @@ -13,9 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging import queue -import statistics import threading import time from abc import ABC, abstractmethod @@ -29,11 +27,11 @@ import torch import torch.nn as nn from tokenizers import Tokenizer from tokenizers.decoders import DecodeStream -from torch.profiler import profile, schedule, tensorboard_trace_handler from tqdm import tqdm from ..configuration_utils import PretrainedConfig from ..generation.configuration_utils import GenerationConfig +from ..utils.logging import logging from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced @@ -49,9 +47,7 @@ class RequestStatus(Enum): FAILED = "failed" -# Setup your logger logger = logging.getLogger(__name__) -logger.setLevel(logging.WARNING) @dataclass @@ -159,8 +155,8 @@ class PagedAttentionCache: generation_config: GenerationConfig, device: torch.device, dtype: torch.dtype = torch.float16, + num_requests: int = 100, layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None, - initial_prompt_shapes: Optional[list[list[int]]] = None, tp_size: Optional[int] = None, ) -> None: """Initialize a paged attention cache for efficient memory usage. @@ -179,23 +175,6 @@ class PagedAttentionCache: if getattr(config, "num_key_value_heads", None) is None else config.num_key_value_heads ) - self.head_dim = ( - config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - ) - self.num_hidden_layers = config.num_hidden_layers - - # Calculate optimal block size and number if not provided - num_blocks = getattr(generation_config, "num_blocks", None) - block_size = getattr(generation_config, "block_size", None) - if num_blocks is None or block_size is None: - logger.info("Calculating optimal block size and number...") - num_blocks, block_size = compute_optimal_blocks( - device, config, generation_config, initial_prompt_shapes or [], dtype, median_prefill_length=200 - ) - logger.info(f"Using calculated num_blocks={num_blocks}, block_size={block_size}") - - self.block_size = block_size - self.num_blocks = num_blocks num_key_value_heads = self.num_key_value_heads if tp_size is not None and tp_size > 1: if num_key_value_heads % tp_size != 0: @@ -203,8 +182,33 @@ class PagedAttentionCache: f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}." ) # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. - num_key_value_heads //= tp_size + self.num_key_value_heads //= tp_size + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + self.num_hidden_layers = config.num_hidden_layers + + # Calculate optimal block size and number if not provided + num_blocks = getattr(generation_config, "num_blocks", None) + block_size = getattr(generation_config, "block_size", 32) + max_memory_percent = getattr(generation_config, "max_memory", 0.9) + num_blocks, max_batch_tokens = compute_optimal_blocks( + generation_config.max_new_tokens, + block_size=block_size, + head_dim=self.head_dim, + num_layers=self.num_hidden_layers, + num_heads=self.num_key_value_heads, + max_memory_percent=max_memory_percent, + dtype=dtype, + num_blocks=num_blocks, + ) + logger.warning( + f"Using calculated num_blocks={num_blocks}, block_size={block_size}, max concurrent requests {max_batch_tokens}" + ) + self.max_batch_tokens = max_batch_tokens + self.block_size = block_size + self.num_blocks = num_blocks self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim) self.dtype = dtype @@ -249,7 +253,7 @@ class PagedAttentionCache: blocks_to_free = self._block_tables.pop(request_id) self._free_blocks.extend(blocks_to_free) else: - logger.warning(f"Attempted to free blocks for non-existent request_id: {request_id}") + logger.info(f"Attempted to free blocks for non-existent request_id: {request_id}") def get_num_free_blocks(self) -> int: """Returns the number of free blocks available.""" @@ -343,7 +347,7 @@ class Scheduler(ABC): @traced def has_pending_requests(self) -> bool: """Check if there are requests ready to be processed.""" - return self.active_requests or self.waiting_requests + return len(self.active_requests) or len(self.waiting_requests) @abstractmethod def finish_request(self, request_id: str, evict_from_cache: bool = True): @@ -595,94 +599,60 @@ class PrefillFirstScheduler(Scheduler): del self.active_requests[request_id] +def get_device_and_memory(): + # Select best available device + if torch.cuda.is_available(): + device = torch.device("cuda") + total_memory = torch.cuda.get_device_properties(device).total_memory + reserved_memory = torch.cuda.memory_reserved(device) + allocated_memory = torch.cuda.memory_allocated(device) + + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + device = torch.device("mps") + # MPS memory reporting (PyTorch 2.0+) + total_memory = torch.mps.driver_allocated_memory() + allocated_memory = total_memory - torch.mps.recommended_max_memory() + reserved_memory = 0 # MPS does not track reserved separately + + else: + device = torch.device("cpu") + total_memory = None + reserved_memory = 0 + allocated_memory = 0 + + return device, total_memory, reserved_memory, allocated_memory + + @traced(standalone=True) def compute_optimal_blocks( - device: torch.device, - config: PretrainedConfig, - generation_config: GenerationConfig, - inputs: list[list[int]], - dtype: torch.dtype = torch.bfloat16, - safety_margin: float = 0.9, - median_prefill_length: Optional[int] = None, + max_num_tokens, + block_size, + head_dim, + num_heads, + num_layers, + max_memory_percent=0.9, + num_blocks=None, + dtype=torch.float16, ): - """Calculate optimal number and size of blocks for the KV cache. + device, total, reserved, allocated = get_device_and_memory() + available_memory = int((total - max(allocated, reserved)) * max_memory_percent) - Args: - device: The device where the model runs - config: The model configuration - generation_config: The generation configuration - inputs: Sample input sequences to estimate memory requirements - dtype: Data type for cache tensors - safety_margin: Fraction of available memory to use - median_prefill_length: Override for median prefill length calculation - - Returns: - Tuple of (num_blocks, block_size) - """ - # Extract model dimensions - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - num_hidden_layers = getattr(config, "num_hidden_layers", 40) - - # Get available device memory - if device.type == "cuda": - device_properties = torch.cuda.get_device_properties(device) - total_memory = device_properties.total_memory - allocated_memory = torch.cuda.memory_allocated(device) - reserved_memory = torch.cuda.memory_reserved(device) - available_memory = total_memory - max(allocated_memory, reserved_memory) - elif device.type == "mps": - logger.warning("MPS memory estimation is approximate. Using conservative defaults.") - return 2048, 256 - else: - logger.warning(f"Unsupported device type {device.type} for optimal block calculation. Using defaults.") - return 32, 128 - - # Apply safety margin - available_memory = int(available_memory * safety_margin) - if available_memory <= 0: - logger.warning("Not enough available memory. Using minimum configuration.") - return 8, 128 # Minimum viable configuration - - # Calculate memory per token dtype_size = torch.tensor([], dtype=dtype).element_size() - memory_per_token = 2 * num_kv_heads * head_dim * dtype_size * num_hidden_layers # For K and V caches - - # Estimate sequence length requirements - tokens_to_generate = getattr(generation_config, "max_new_tokens") or 20 - - if median_prefill_length is None and inputs: - non_empty_inputs = [len(seq) for seq in inputs if seq] - median_prefill_length = int(statistics.median(non_empty_inputs)) if non_empty_inputs else 64 - elif median_prefill_length is None: - median_prefill_length = 64 # Reasonable default if no inputs provided - - # Total sequence length including generated tokens - seq_length = median_prefill_length + tokens_to_generate - - # Calculate block parameters - MIN_BLOCK_SIZE = 16 - - # Estimate number of concurrent sequences - per_sequence_memory = seq_length * memory_per_token - max_concurrent_sequences = max(1, int(available_memory // per_sequence_memory)) - - # Total tokens that can fit in memory - total_tokens = available_memory // memory_per_token - - # Calculate block size (rounded to power of 2) - initial_block_size = max(MIN_BLOCK_SIZE, total_tokens // (max_concurrent_sequences * 2)) - block_size = 1 << (initial_block_size - 1).bit_length() # Round to power of 2 - - # Calculate number of blocks - num_blocks = max(1, total_tokens // block_size) - - logger.info( - f"Optimal cache: {num_blocks} blocks of size {block_size} " - f"(can handle ~{num_blocks * block_size // seq_length} sequences of length {seq_length})" + bytes_per_token = 2 * num_heads * head_dim * dtype_size * num_layers + if num_blocks is not None: + # TODO + max_possible_concurrent_requests = num_blocks * bytes_per_token + # FIXME: forgot to add the inintial prompt length in the mix.... + max_possible_concurrent_requests = int( + available_memory // (bytes_per_token * max_num_tokens * max_num_tokens // 4) ) - - return int(num_blocks), int(block_size) + if max_possible_concurrent_requests <= 0: + logger.warning("you are trying to generate a bit too many tokens") + max_possible_concurrent_requests = 32 + max_concurrent_tokens = min(64, max_possible_concurrent_requests) + # FIXME: Optimal means uses all memory + optimal_num_blocks = max(((max_concurrent_tokens * max_num_tokens) // block_size) + 1, 64) + return optimal_num_blocks, max_concurrent_tokens @dataclass @@ -775,11 +745,9 @@ class ContinuousBatchProcessor: self.requests_in_batch: list[RequestState] = [] - # Get batch size parameters from generation config - self._configure_batch_parameters() - # Set up metrics collector - self.metrics = ContinuousBatchProcessorMetrics(self.max_batch_tokens) + self.max_batch_tokens = cache.max_batch_tokens + self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens) self.setup_static_tensors() @@ -847,25 +815,6 @@ class ContinuousBatchProcessor: + self.get_model_kwargs().__repr__() ) - @traced(standalone=True) - def _configure_batch_parameters(self): - """Set up batch processing parameters based on generation config.""" - # Calculate total cache capacity - total_cache_tokens = self.cache.num_blocks * self.cache.block_size - - # Get or calculate max tokens per batch - user_batch_tokens = getattr(self.generation_config, "max_batch_tokens", None) - if user_batch_tokens is not None: - self.max_batch_tokens = user_batch_tokens - else: - # Default to 1/8 of total cache capacity, adjusted for context - self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048) - recommended_batch_size = min(total_cache_tokens // 8, self.max_context_len) - self.max_batch_tokens = max(64, recommended_batch_size) - - # Context length and EOS token - self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048) - @traced def _get_new_requests(self): """Pull new requests from the input queue and add to waiting list.""" @@ -1041,6 +990,8 @@ class ContinuousBatchProcessor: self._maybe_send_output(state, token) elif state.status == RequestStatus.PREFILLING_SPLIT: state.status = RequestStatus.SPLIT_PENDING_REMAINDER + if self.cache.get_num_free_blocks() == 0: + raise ValueError("No more free blocks") @traced def has_pending_requests(self) -> bool: @@ -1062,7 +1013,9 @@ class ContinuousBatchProcessor: Args: error: The error to report in the failure message """ - for state in self.scheduler.active_requests.values(): + + requests = list(self.scheduler.active_requests.values()) + for state in requests: self._handle_request_error(error, state) self.scheduler.finish_request(state.request_id) @@ -1296,6 +1249,7 @@ class ContinuousBatchingManager: self.generation_config, self.model.device, self.model.dtype, + num_requests=len(self.input_queue.queue), tp_size=getattr(self.model, "tp_size"), ) @@ -1324,33 +1278,10 @@ class ContinuousBatchingManager: ) self.batch_processor = batch_processor is_first = True - - if self.profile: - tracing_schedule = schedule(skip_first=2, warmup=1, active=1, repeat=3, wait=1) - trace_handler = tensorboard_trace_handler( - dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile" - ) - activities = [ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ] - with profile( - activities=activities, - schedule=tracing_schedule, - on_trace_ready=trace_handler, - record_shapes=False, - with_stack=True, - ) as prof: - while not self.stop_event.is_set() or batch_processor.has_pending_requests(): - self._inner_generation_loop(batch_processor, is_first) - if is_first: - is_first = False - prof.step() - else: - while not self.stop_event.is_set() or batch_processor.has_pending_requests(): - self._inner_generation_loop(batch_processor, is_first) - if is_first: - is_first = False + while (not self.stop_event.is_set()) or batch_processor.has_pending_requests(): + self._inner_generation_loop(batch_processor, is_first) + if is_first: + is_first = False except Exception as e: logger.error(f"Error in generation loop: {e}", exc_info=True) @@ -1363,6 +1294,8 @@ class ContinuousBatchingManager: if torch.cuda.is_available(): torch.cuda.synchronize() batch_processor.prepare_next_batch() + device, total, reserved, allocated = get_device_and_memory() + logger.info(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}") if torch.cuda.is_available() and self.use_cuda_graph: if is_first: self.warmup(batch_processor) @@ -1502,6 +1435,7 @@ class ContinuousMixin: results[req_id] = result finished_count += 1 pbar.update(1) + logger.debug(manager.batch_processor.tokenizer.decode(result.generated_tokens)) else: if not manager.is_running(): logger.error("Generation thread terminated unexpectedly.") diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py index 5c91714486..a7bf5ae577 100644 --- a/src/transformers/integrations/flash_paged.py +++ b/src/transformers/integrations/flash_paged.py @@ -4,8 +4,11 @@ from ..generation.continuous_batching import PagedAttentionCache from ..utils import is_flash_attn_2_available -if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func # noqa: F401 +try: + if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func # noqa: F401 +except Exception: + pass def paged_attention_forward( diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5a526e5dd9..03e9cf5314 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2705,9 +2705,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH kernel_function = partial(attention_wrapper, implementation=kernel) elif kernel_name is not None: kernel_function = getattr(kernel, kernel_name) - ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function) + ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function) ALL_MASK_ATTENTION_FUNCTIONS.register( - applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] + attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] ) except Exception as e: logger.warning_once( @@ -2715,8 +2715,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH "default attention implementation instead (sdpa if available, eager otherwise)." ) - applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case - return applicable_attn_implementation + attn_implementation = "sdpa" # Try to fallback to sdpa in this case + return attn_implementation else: return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)