|
|
|
|
@@ -13,9 +13,7 @@
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import logging
|
|
|
|
|
import queue
|
|
|
|
|
import statistics
|
|
|
|
|
import threading
|
|
|
|
|
import time
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
@@ -29,11 +27,11 @@ import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
from tokenizers import Tokenizer
|
|
|
|
|
from tokenizers.decoders import DecodeStream
|
|
|
|
|
from torch.profiler import profile, schedule, tensorboard_trace_handler
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
from ..configuration_utils import PretrainedConfig
|
|
|
|
|
from ..generation.configuration_utils import GenerationConfig
|
|
|
|
|
from ..utils.logging import logging
|
|
|
|
|
from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -49,9 +47,7 @@ class RequestStatus(Enum):
|
|
|
|
|
FAILED = "failed"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Setup your logger
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
logger.setLevel(logging.WARNING)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
@@ -159,8 +155,8 @@ class PagedAttentionCache:
|
|
|
|
|
generation_config: GenerationConfig,
|
|
|
|
|
device: torch.device,
|
|
|
|
|
dtype: torch.dtype = torch.float16,
|
|
|
|
|
num_requests: int = 100,
|
|
|
|
|
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
|
|
|
|
|
initial_prompt_shapes: Optional[list[list[int]]] = None,
|
|
|
|
|
tp_size: Optional[int] = None,
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Initialize a paged attention cache for efficient memory usage.
|
|
|
|
|
@@ -179,23 +175,6 @@ class PagedAttentionCache:
|
|
|
|
|
if getattr(config, "num_key_value_heads", None) is None
|
|
|
|
|
else config.num_key_value_heads
|
|
|
|
|
)
|
|
|
|
|
self.head_dim = (
|
|
|
|
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
|
|
|
|
)
|
|
|
|
|
self.num_hidden_layers = config.num_hidden_layers
|
|
|
|
|
|
|
|
|
|
# Calculate optimal block size and number if not provided
|
|
|
|
|
num_blocks = getattr(generation_config, "num_blocks", None)
|
|
|
|
|
block_size = getattr(generation_config, "block_size", None)
|
|
|
|
|
if num_blocks is None or block_size is None:
|
|
|
|
|
logger.info("Calculating optimal block size and number...")
|
|
|
|
|
num_blocks, block_size = compute_optimal_blocks(
|
|
|
|
|
device, config, generation_config, initial_prompt_shapes or [], dtype, median_prefill_length=200
|
|
|
|
|
)
|
|
|
|
|
logger.info(f"Using calculated num_blocks={num_blocks}, block_size={block_size}")
|
|
|
|
|
|
|
|
|
|
self.block_size = block_size
|
|
|
|
|
self.num_blocks = num_blocks
|
|
|
|
|
num_key_value_heads = self.num_key_value_heads
|
|
|
|
|
if tp_size is not None and tp_size > 1:
|
|
|
|
|
if num_key_value_heads % tp_size != 0:
|
|
|
|
|
@@ -203,8 +182,33 @@ class PagedAttentionCache:
|
|
|
|
|
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
|
|
|
|
|
)
|
|
|
|
|
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
|
|
|
|
|
num_key_value_heads //= tp_size
|
|
|
|
|
self.num_key_value_heads //= tp_size
|
|
|
|
|
|
|
|
|
|
self.head_dim = (
|
|
|
|
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
|
|
|
|
)
|
|
|
|
|
self.num_hidden_layers = config.num_hidden_layers
|
|
|
|
|
|
|
|
|
|
# Calculate optimal block size and number if not provided
|
|
|
|
|
num_blocks = getattr(generation_config, "num_blocks", None)
|
|
|
|
|
block_size = getattr(generation_config, "block_size", 32)
|
|
|
|
|
max_memory_percent = getattr(generation_config, "max_memory", 0.9)
|
|
|
|
|
num_blocks, max_batch_tokens = compute_optimal_blocks(
|
|
|
|
|
generation_config.max_new_tokens,
|
|
|
|
|
block_size=block_size,
|
|
|
|
|
head_dim=self.head_dim,
|
|
|
|
|
num_layers=self.num_hidden_layers,
|
|
|
|
|
num_heads=self.num_key_value_heads,
|
|
|
|
|
max_memory_percent=max_memory_percent,
|
|
|
|
|
dtype=dtype,
|
|
|
|
|
num_blocks=num_blocks,
|
|
|
|
|
)
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"Using calculated num_blocks={num_blocks}, block_size={block_size}, max concurrent requests {max_batch_tokens}"
|
|
|
|
|
)
|
|
|
|
|
self.max_batch_tokens = max_batch_tokens
|
|
|
|
|
self.block_size = block_size
|
|
|
|
|
self.num_blocks = num_blocks
|
|
|
|
|
self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim)
|
|
|
|
|
|
|
|
|
|
self.dtype = dtype
|
|
|
|
|
@@ -249,7 +253,7 @@ class PagedAttentionCache:
|
|
|
|
|
blocks_to_free = self._block_tables.pop(request_id)
|
|
|
|
|
self._free_blocks.extend(blocks_to_free)
|
|
|
|
|
else:
|
|
|
|
|
logger.warning(f"Attempted to free blocks for non-existent request_id: {request_id}")
|
|
|
|
|
logger.info(f"Attempted to free blocks for non-existent request_id: {request_id}")
|
|
|
|
|
|
|
|
|
|
def get_num_free_blocks(self) -> int:
|
|
|
|
|
"""Returns the number of free blocks available."""
|
|
|
|
|
@@ -343,7 +347,7 @@ class Scheduler(ABC):
|
|
|
|
|
@traced
|
|
|
|
|
def has_pending_requests(self) -> bool:
|
|
|
|
|
"""Check if there are requests ready to be processed."""
|
|
|
|
|
return self.active_requests or self.waiting_requests
|
|
|
|
|
return len(self.active_requests) or len(self.waiting_requests)
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def finish_request(self, request_id: str, evict_from_cache: bool = True):
|
|
|
|
|
@@ -595,94 +599,60 @@ class PrefillFirstScheduler(Scheduler):
|
|
|
|
|
del self.active_requests[request_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_device_and_memory():
|
|
|
|
|
# Select best available device
|
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
|
device = torch.device("cuda")
|
|
|
|
|
total_memory = torch.cuda.get_device_properties(device).total_memory
|
|
|
|
|
reserved_memory = torch.cuda.memory_reserved(device)
|
|
|
|
|
allocated_memory = torch.cuda.memory_allocated(device)
|
|
|
|
|
|
|
|
|
|
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
|
|
|
|
device = torch.device("mps")
|
|
|
|
|
# MPS memory reporting (PyTorch 2.0+)
|
|
|
|
|
total_memory = torch.mps.driver_allocated_memory()
|
|
|
|
|
allocated_memory = total_memory - torch.mps.recommended_max_memory()
|
|
|
|
|
reserved_memory = 0 # MPS does not track reserved separately
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
device = torch.device("cpu")
|
|
|
|
|
total_memory = None
|
|
|
|
|
reserved_memory = 0
|
|
|
|
|
allocated_memory = 0
|
|
|
|
|
|
|
|
|
|
return device, total_memory, reserved_memory, allocated_memory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@traced(standalone=True)
|
|
|
|
|
def compute_optimal_blocks(
|
|
|
|
|
device: torch.device,
|
|
|
|
|
config: PretrainedConfig,
|
|
|
|
|
generation_config: GenerationConfig,
|
|
|
|
|
inputs: list[list[int]],
|
|
|
|
|
dtype: torch.dtype = torch.bfloat16,
|
|
|
|
|
safety_margin: float = 0.9,
|
|
|
|
|
median_prefill_length: Optional[int] = None,
|
|
|
|
|
max_num_tokens,
|
|
|
|
|
block_size,
|
|
|
|
|
head_dim,
|
|
|
|
|
num_heads,
|
|
|
|
|
num_layers,
|
|
|
|
|
max_memory_percent=0.9,
|
|
|
|
|
num_blocks=None,
|
|
|
|
|
dtype=torch.float16,
|
|
|
|
|
):
|
|
|
|
|
"""Calculate optimal number and size of blocks for the KV cache.
|
|
|
|
|
device, total, reserved, allocated = get_device_and_memory()
|
|
|
|
|
available_memory = int((total - max(allocated, reserved)) * max_memory_percent)
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
device: The device where the model runs
|
|
|
|
|
config: The model configuration
|
|
|
|
|
generation_config: The generation configuration
|
|
|
|
|
inputs: Sample input sequences to estimate memory requirements
|
|
|
|
|
dtype: Data type for cache tensors
|
|
|
|
|
safety_margin: Fraction of available memory to use
|
|
|
|
|
median_prefill_length: Override for median prefill length calculation
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple of (num_blocks, block_size)
|
|
|
|
|
"""
|
|
|
|
|
# Extract model dimensions
|
|
|
|
|
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
|
|
|
|
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
|
|
|
|
|
num_hidden_layers = getattr(config, "num_hidden_layers", 40)
|
|
|
|
|
|
|
|
|
|
# Get available device memory
|
|
|
|
|
if device.type == "cuda":
|
|
|
|
|
device_properties = torch.cuda.get_device_properties(device)
|
|
|
|
|
total_memory = device_properties.total_memory
|
|
|
|
|
allocated_memory = torch.cuda.memory_allocated(device)
|
|
|
|
|
reserved_memory = torch.cuda.memory_reserved(device)
|
|
|
|
|
available_memory = total_memory - max(allocated_memory, reserved_memory)
|
|
|
|
|
elif device.type == "mps":
|
|
|
|
|
logger.warning("MPS memory estimation is approximate. Using conservative defaults.")
|
|
|
|
|
return 2048, 256
|
|
|
|
|
else:
|
|
|
|
|
logger.warning(f"Unsupported device type {device.type} for optimal block calculation. Using defaults.")
|
|
|
|
|
return 32, 128
|
|
|
|
|
|
|
|
|
|
# Apply safety margin
|
|
|
|
|
available_memory = int(available_memory * safety_margin)
|
|
|
|
|
if available_memory <= 0:
|
|
|
|
|
logger.warning("Not enough available memory. Using minimum configuration.")
|
|
|
|
|
return 8, 128 # Minimum viable configuration
|
|
|
|
|
|
|
|
|
|
# Calculate memory per token
|
|
|
|
|
dtype_size = torch.tensor([], dtype=dtype).element_size()
|
|
|
|
|
memory_per_token = 2 * num_kv_heads * head_dim * dtype_size * num_hidden_layers # For K and V caches
|
|
|
|
|
|
|
|
|
|
# Estimate sequence length requirements
|
|
|
|
|
tokens_to_generate = getattr(generation_config, "max_new_tokens") or 20
|
|
|
|
|
|
|
|
|
|
if median_prefill_length is None and inputs:
|
|
|
|
|
non_empty_inputs = [len(seq) for seq in inputs if seq]
|
|
|
|
|
median_prefill_length = int(statistics.median(non_empty_inputs)) if non_empty_inputs else 64
|
|
|
|
|
elif median_prefill_length is None:
|
|
|
|
|
median_prefill_length = 64 # Reasonable default if no inputs provided
|
|
|
|
|
|
|
|
|
|
# Total sequence length including generated tokens
|
|
|
|
|
seq_length = median_prefill_length + tokens_to_generate
|
|
|
|
|
|
|
|
|
|
# Calculate block parameters
|
|
|
|
|
MIN_BLOCK_SIZE = 16
|
|
|
|
|
|
|
|
|
|
# Estimate number of concurrent sequences
|
|
|
|
|
per_sequence_memory = seq_length * memory_per_token
|
|
|
|
|
max_concurrent_sequences = max(1, int(available_memory // per_sequence_memory))
|
|
|
|
|
|
|
|
|
|
# Total tokens that can fit in memory
|
|
|
|
|
total_tokens = available_memory // memory_per_token
|
|
|
|
|
|
|
|
|
|
# Calculate block size (rounded to power of 2)
|
|
|
|
|
initial_block_size = max(MIN_BLOCK_SIZE, total_tokens // (max_concurrent_sequences * 2))
|
|
|
|
|
block_size = 1 << (initial_block_size - 1).bit_length() # Round to power of 2
|
|
|
|
|
|
|
|
|
|
# Calculate number of blocks
|
|
|
|
|
num_blocks = max(1, total_tokens // block_size)
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Optimal cache: {num_blocks} blocks of size {block_size} "
|
|
|
|
|
f"(can handle ~{num_blocks * block_size // seq_length} sequences of length {seq_length})"
|
|
|
|
|
bytes_per_token = 2 * num_heads * head_dim * dtype_size * num_layers
|
|
|
|
|
if num_blocks is not None:
|
|
|
|
|
# TODO
|
|
|
|
|
max_possible_concurrent_requests = num_blocks * bytes_per_token
|
|
|
|
|
# FIXME: forgot to add the inintial prompt length in the mix....
|
|
|
|
|
max_possible_concurrent_requests = int(
|
|
|
|
|
available_memory // (bytes_per_token * max_num_tokens * max_num_tokens // 4)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return int(num_blocks), int(block_size)
|
|
|
|
|
if max_possible_concurrent_requests <= 0:
|
|
|
|
|
logger.warning("you are trying to generate a bit too many tokens")
|
|
|
|
|
max_possible_concurrent_requests = 32
|
|
|
|
|
max_concurrent_tokens = min(64, max_possible_concurrent_requests)
|
|
|
|
|
# FIXME: Optimal means uses all memory
|
|
|
|
|
optimal_num_blocks = max(((max_concurrent_tokens * max_num_tokens) // block_size) + 1, 64)
|
|
|
|
|
return optimal_num_blocks, max_concurrent_tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
@@ -775,11 +745,9 @@ class ContinuousBatchProcessor:
|
|
|
|
|
|
|
|
|
|
self.requests_in_batch: list[RequestState] = []
|
|
|
|
|
|
|
|
|
|
# Get batch size parameters from generation config
|
|
|
|
|
self._configure_batch_parameters()
|
|
|
|
|
|
|
|
|
|
# Set up metrics collector
|
|
|
|
|
self.metrics = ContinuousBatchProcessorMetrics(self.max_batch_tokens)
|
|
|
|
|
self.max_batch_tokens = cache.max_batch_tokens
|
|
|
|
|
self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens)
|
|
|
|
|
|
|
|
|
|
self.setup_static_tensors()
|
|
|
|
|
|
|
|
|
|
@@ -847,25 +815,6 @@ class ContinuousBatchProcessor:
|
|
|
|
|
+ self.get_model_kwargs().__repr__()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@traced(standalone=True)
|
|
|
|
|
def _configure_batch_parameters(self):
|
|
|
|
|
"""Set up batch processing parameters based on generation config."""
|
|
|
|
|
# Calculate total cache capacity
|
|
|
|
|
total_cache_tokens = self.cache.num_blocks * self.cache.block_size
|
|
|
|
|
|
|
|
|
|
# Get or calculate max tokens per batch
|
|
|
|
|
user_batch_tokens = getattr(self.generation_config, "max_batch_tokens", None)
|
|
|
|
|
if user_batch_tokens is not None:
|
|
|
|
|
self.max_batch_tokens = user_batch_tokens
|
|
|
|
|
else:
|
|
|
|
|
# Default to 1/8 of total cache capacity, adjusted for context
|
|
|
|
|
self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048)
|
|
|
|
|
recommended_batch_size = min(total_cache_tokens // 8, self.max_context_len)
|
|
|
|
|
self.max_batch_tokens = max(64, recommended_batch_size)
|
|
|
|
|
|
|
|
|
|
# Context length and EOS token
|
|
|
|
|
self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048)
|
|
|
|
|
|
|
|
|
|
@traced
|
|
|
|
|
def _get_new_requests(self):
|
|
|
|
|
"""Pull new requests from the input queue and add to waiting list."""
|
|
|
|
|
@@ -1041,6 +990,8 @@ class ContinuousBatchProcessor:
|
|
|
|
|
self._maybe_send_output(state, token)
|
|
|
|
|
elif state.status == RequestStatus.PREFILLING_SPLIT:
|
|
|
|
|
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
|
|
|
|
|
if self.cache.get_num_free_blocks() == 0:
|
|
|
|
|
raise ValueError("No more free blocks")
|
|
|
|
|
|
|
|
|
|
@traced
|
|
|
|
|
def has_pending_requests(self) -> bool:
|
|
|
|
|
@@ -1062,7 +1013,9 @@ class ContinuousBatchProcessor:
|
|
|
|
|
Args:
|
|
|
|
|
error: The error to report in the failure message
|
|
|
|
|
"""
|
|
|
|
|
for state in self.scheduler.active_requests.values():
|
|
|
|
|
|
|
|
|
|
requests = list(self.scheduler.active_requests.values())
|
|
|
|
|
for state in requests:
|
|
|
|
|
self._handle_request_error(error, state)
|
|
|
|
|
self.scheduler.finish_request(state.request_id)
|
|
|
|
|
|
|
|
|
|
@@ -1296,6 +1249,7 @@ class ContinuousBatchingManager:
|
|
|
|
|
self.generation_config,
|
|
|
|
|
self.model.device,
|
|
|
|
|
self.model.dtype,
|
|
|
|
|
num_requests=len(self.input_queue.queue),
|
|
|
|
|
tp_size=getattr(self.model, "tp_size"),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -1324,30 +1278,7 @@ class ContinuousBatchingManager:
|
|
|
|
|
)
|
|
|
|
|
self.batch_processor = batch_processor
|
|
|
|
|
is_first = True
|
|
|
|
|
|
|
|
|
|
if self.profile:
|
|
|
|
|
tracing_schedule = schedule(skip_first=2, warmup=1, active=1, repeat=3, wait=1)
|
|
|
|
|
trace_handler = tensorboard_trace_handler(
|
|
|
|
|
dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile"
|
|
|
|
|
)
|
|
|
|
|
activities = [
|
|
|
|
|
torch.profiler.ProfilerActivity.CPU,
|
|
|
|
|
torch.profiler.ProfilerActivity.CUDA,
|
|
|
|
|
]
|
|
|
|
|
with profile(
|
|
|
|
|
activities=activities,
|
|
|
|
|
schedule=tracing_schedule,
|
|
|
|
|
on_trace_ready=trace_handler,
|
|
|
|
|
record_shapes=False,
|
|
|
|
|
with_stack=True,
|
|
|
|
|
) as prof:
|
|
|
|
|
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
|
|
|
|
|
self._inner_generation_loop(batch_processor, is_first)
|
|
|
|
|
if is_first:
|
|
|
|
|
is_first = False
|
|
|
|
|
prof.step()
|
|
|
|
|
else:
|
|
|
|
|
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
|
|
|
|
|
while (not self.stop_event.is_set()) or batch_processor.has_pending_requests():
|
|
|
|
|
self._inner_generation_loop(batch_processor, is_first)
|
|
|
|
|
if is_first:
|
|
|
|
|
is_first = False
|
|
|
|
|
@@ -1363,6 +1294,8 @@ class ContinuousBatchingManager:
|
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
|
batch_processor.prepare_next_batch()
|
|
|
|
|
device, total, reserved, allocated = get_device_and_memory()
|
|
|
|
|
logger.info(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}")
|
|
|
|
|
if torch.cuda.is_available() and self.use_cuda_graph:
|
|
|
|
|
if is_first:
|
|
|
|
|
self.warmup(batch_processor)
|
|
|
|
|
@@ -1502,6 +1435,7 @@ class ContinuousMixin:
|
|
|
|
|
results[req_id] = result
|
|
|
|
|
finished_count += 1
|
|
|
|
|
pbar.update(1)
|
|
|
|
|
logger.debug(manager.batch_processor.tokenizer.decode(result.generated_tokens))
|
|
|
|
|
else:
|
|
|
|
|
if not manager.is_running():
|
|
|
|
|
logger.error("Generation thread terminated unexpectedly.")
|
|
|
|
|
|