Update ux cb (#39845)
* clenaup * nits * updates * fix logging * push updates? * just passexception * update * nits * fix * add tokencount * style
This commit is contained in:
@@ -10,26 +10,27 @@ from transformers.generation import GenerationConfig
|
|||||||
torch.set_float32_matmul_precision("high")
|
torch.set_float32_matmul_precision("high")
|
||||||
|
|
||||||
model_id = "meta-llama/Llama-3.2-3b-Instruct"
|
model_id = "meta-llama/Llama-3.2-3b-Instruct"
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = (
|
||||||
model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
|
AutoModelForCausalLM.from_pretrained(
|
||||||
).eval()
|
model_id,
|
||||||
|
attn_implementation="paged_attention|kernels-community/flash-attn",
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
.eval()
|
||||||
|
.cuda()
|
||||||
|
)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
|
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
|
||||||
|
|
||||||
generation_config = GenerationConfig(
|
generation_config = GenerationConfig(
|
||||||
max_new_tokens=512,
|
max_new_tokens=512,
|
||||||
|
# use_cuda_graph=False,
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
use_cache=False,
|
do_sample=False,
|
||||||
num_blocks=2048,
|
|
||||||
block_size=128,
|
|
||||||
do_sample=True,
|
|
||||||
max_batch_tokens=1024, # Maximum number of tokens to process in a single batch
|
|
||||||
scheduler="prefill_first",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
|
train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
|
||||||
|
train_dataset = train_dataset.select(range(500)) # Use only 5 examples for the simple version
|
||||||
# --- Example 1: Simple Version using generate_batch ---
|
|
||||||
print("--- Running CB Generation Example ---")
|
print("--- Running CB Generation Example ---")
|
||||||
|
|
||||||
|
|
||||||
@@ -41,19 +42,21 @@ tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
|
|||||||
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
|
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
|
||||||
|
|
||||||
start_time_simple = time.time()
|
start_time_simple = time.time()
|
||||||
# model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs", fullgraph=True)
|
model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
|
||||||
batch_outputs = model.generate_batch(
|
batch_outputs = model.generate_batch(
|
||||||
inputs=simple_batch_inputs,
|
inputs=simple_batch_inputs,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
)
|
)
|
||||||
end_time_simple = time.time()
|
end_time_simple = time.time()
|
||||||
|
token_count = 0
|
||||||
for request in batch_outputs:
|
for request in batch_outputs:
|
||||||
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
|
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
|
||||||
try:
|
try:
|
||||||
output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
|
output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
|
||||||
|
token_count += len(batch_outputs[request].generated_tokens[1:])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Decoding failed for request {request}: {e}")
|
print(f"Decoding failed for request {request}: {e}")
|
||||||
|
token_count += len(batch_outputs[request].generated_tokens[1:])
|
||||||
output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False)
|
output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False)
|
||||||
if len(output_text) > 0:
|
if len(output_text) > 0:
|
||||||
print("-" * 20)
|
print("-" * 20)
|
||||||
@@ -65,7 +68,9 @@ print("-" * 20)
|
|||||||
print("--- Finished CB Generation Example ---\n\n")
|
print("--- Finished CB Generation Example ---\n\n")
|
||||||
|
|
||||||
|
|
||||||
print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")
|
print(
|
||||||
|
f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds for {token_count} tokens. {token_count / (end_time_simple - start_time_simple)}tok/s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version
|
# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version
|
||||||
|
|||||||
@@ -13,9 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
|
||||||
import queue
|
import queue
|
||||||
import statistics
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -29,11 +27,11 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from tokenizers import Tokenizer
|
from tokenizers import Tokenizer
|
||||||
from tokenizers.decoders import DecodeStream
|
from tokenizers.decoders import DecodeStream
|
||||||
from torch.profiler import profile, schedule, tensorboard_trace_handler
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ..configuration_utils import PretrainedConfig
|
from ..configuration_utils import PretrainedConfig
|
||||||
from ..generation.configuration_utils import GenerationConfig
|
from ..generation.configuration_utils import GenerationConfig
|
||||||
|
from ..utils.logging import logging
|
||||||
from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
|
from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
|
||||||
|
|
||||||
|
|
||||||
@@ -49,9 +47,7 @@ class RequestStatus(Enum):
|
|||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
# Setup your logger
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -159,8 +155,8 @@ class PagedAttentionCache:
|
|||||||
generation_config: GenerationConfig,
|
generation_config: GenerationConfig,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype = torch.float16,
|
dtype: torch.dtype = torch.float16,
|
||||||
|
num_requests: int = 100,
|
||||||
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
|
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
|
||||||
initial_prompt_shapes: Optional[list[list[int]]] = None,
|
|
||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a paged attention cache for efficient memory usage.
|
"""Initialize a paged attention cache for efficient memory usage.
|
||||||
@@ -179,23 +175,6 @@ class PagedAttentionCache:
|
|||||||
if getattr(config, "num_key_value_heads", None) is None
|
if getattr(config, "num_key_value_heads", None) is None
|
||||||
else config.num_key_value_heads
|
else config.num_key_value_heads
|
||||||
)
|
)
|
||||||
self.head_dim = (
|
|
||||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
|
||||||
)
|
|
||||||
self.num_hidden_layers = config.num_hidden_layers
|
|
||||||
|
|
||||||
# Calculate optimal block size and number if not provided
|
|
||||||
num_blocks = getattr(generation_config, "num_blocks", None)
|
|
||||||
block_size = getattr(generation_config, "block_size", None)
|
|
||||||
if num_blocks is None or block_size is None:
|
|
||||||
logger.info("Calculating optimal block size and number...")
|
|
||||||
num_blocks, block_size = compute_optimal_blocks(
|
|
||||||
device, config, generation_config, initial_prompt_shapes or [], dtype, median_prefill_length=200
|
|
||||||
)
|
|
||||||
logger.info(f"Using calculated num_blocks={num_blocks}, block_size={block_size}")
|
|
||||||
|
|
||||||
self.block_size = block_size
|
|
||||||
self.num_blocks = num_blocks
|
|
||||||
num_key_value_heads = self.num_key_value_heads
|
num_key_value_heads = self.num_key_value_heads
|
||||||
if tp_size is not None and tp_size > 1:
|
if tp_size is not None and tp_size > 1:
|
||||||
if num_key_value_heads % tp_size != 0:
|
if num_key_value_heads % tp_size != 0:
|
||||||
@@ -203,8 +182,33 @@ class PagedAttentionCache:
|
|||||||
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
|
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
|
||||||
)
|
)
|
||||||
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
|
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
|
||||||
num_key_value_heads //= tp_size
|
self.num_key_value_heads //= tp_size
|
||||||
|
|
||||||
|
self.head_dim = (
|
||||||
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||||
|
)
|
||||||
|
self.num_hidden_layers = config.num_hidden_layers
|
||||||
|
|
||||||
|
# Calculate optimal block size and number if not provided
|
||||||
|
num_blocks = getattr(generation_config, "num_blocks", None)
|
||||||
|
block_size = getattr(generation_config, "block_size", 32)
|
||||||
|
max_memory_percent = getattr(generation_config, "max_memory", 0.9)
|
||||||
|
num_blocks, max_batch_tokens = compute_optimal_blocks(
|
||||||
|
generation_config.max_new_tokens,
|
||||||
|
block_size=block_size,
|
||||||
|
head_dim=self.head_dim,
|
||||||
|
num_layers=self.num_hidden_layers,
|
||||||
|
num_heads=self.num_key_value_heads,
|
||||||
|
max_memory_percent=max_memory_percent,
|
||||||
|
dtype=dtype,
|
||||||
|
num_blocks=num_blocks,
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"Using calculated num_blocks={num_blocks}, block_size={block_size}, max concurrent requests {max_batch_tokens}"
|
||||||
|
)
|
||||||
|
self.max_batch_tokens = max_batch_tokens
|
||||||
|
self.block_size = block_size
|
||||||
|
self.num_blocks = num_blocks
|
||||||
self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim)
|
self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim)
|
||||||
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
@@ -249,7 +253,7 @@ class PagedAttentionCache:
|
|||||||
blocks_to_free = self._block_tables.pop(request_id)
|
blocks_to_free = self._block_tables.pop(request_id)
|
||||||
self._free_blocks.extend(blocks_to_free)
|
self._free_blocks.extend(blocks_to_free)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Attempted to free blocks for non-existent request_id: {request_id}")
|
logger.info(f"Attempted to free blocks for non-existent request_id: {request_id}")
|
||||||
|
|
||||||
def get_num_free_blocks(self) -> int:
|
def get_num_free_blocks(self) -> int:
|
||||||
"""Returns the number of free blocks available."""
|
"""Returns the number of free blocks available."""
|
||||||
@@ -343,7 +347,7 @@ class Scheduler(ABC):
|
|||||||
@traced
|
@traced
|
||||||
def has_pending_requests(self) -> bool:
|
def has_pending_requests(self) -> bool:
|
||||||
"""Check if there are requests ready to be processed."""
|
"""Check if there are requests ready to be processed."""
|
||||||
return self.active_requests or self.waiting_requests
|
return len(self.active_requests) or len(self.waiting_requests)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def finish_request(self, request_id: str, evict_from_cache: bool = True):
|
def finish_request(self, request_id: str, evict_from_cache: bool = True):
|
||||||
@@ -595,94 +599,60 @@ class PrefillFirstScheduler(Scheduler):
|
|||||||
del self.active_requests[request_id]
|
del self.active_requests[request_id]
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_and_memory():
|
||||||
|
# Select best available device
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda")
|
||||||
|
total_memory = torch.cuda.get_device_properties(device).total_memory
|
||||||
|
reserved_memory = torch.cuda.memory_reserved(device)
|
||||||
|
allocated_memory = torch.cuda.memory_allocated(device)
|
||||||
|
|
||||||
|
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
||||||
|
device = torch.device("mps")
|
||||||
|
# MPS memory reporting (PyTorch 2.0+)
|
||||||
|
total_memory = torch.mps.driver_allocated_memory()
|
||||||
|
allocated_memory = total_memory - torch.mps.recommended_max_memory()
|
||||||
|
reserved_memory = 0 # MPS does not track reserved separately
|
||||||
|
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
total_memory = None
|
||||||
|
reserved_memory = 0
|
||||||
|
allocated_memory = 0
|
||||||
|
|
||||||
|
return device, total_memory, reserved_memory, allocated_memory
|
||||||
|
|
||||||
|
|
||||||
@traced(standalone=True)
|
@traced(standalone=True)
|
||||||
def compute_optimal_blocks(
|
def compute_optimal_blocks(
|
||||||
device: torch.device,
|
max_num_tokens,
|
||||||
config: PretrainedConfig,
|
block_size,
|
||||||
generation_config: GenerationConfig,
|
head_dim,
|
||||||
inputs: list[list[int]],
|
num_heads,
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
num_layers,
|
||||||
safety_margin: float = 0.9,
|
max_memory_percent=0.9,
|
||||||
median_prefill_length: Optional[int] = None,
|
num_blocks=None,
|
||||||
|
dtype=torch.float16,
|
||||||
):
|
):
|
||||||
"""Calculate optimal number and size of blocks for the KV cache.
|
device, total, reserved, allocated = get_device_and_memory()
|
||||||
|
available_memory = int((total - max(allocated, reserved)) * max_memory_percent)
|
||||||
|
|
||||||
Args:
|
|
||||||
device: The device where the model runs
|
|
||||||
config: The model configuration
|
|
||||||
generation_config: The generation configuration
|
|
||||||
inputs: Sample input sequences to estimate memory requirements
|
|
||||||
dtype: Data type for cache tensors
|
|
||||||
safety_margin: Fraction of available memory to use
|
|
||||||
median_prefill_length: Override for median prefill length calculation
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (num_blocks, block_size)
|
|
||||||
"""
|
|
||||||
# Extract model dimensions
|
|
||||||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
|
||||||
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
|
|
||||||
num_hidden_layers = getattr(config, "num_hidden_layers", 40)
|
|
||||||
|
|
||||||
# Get available device memory
|
|
||||||
if device.type == "cuda":
|
|
||||||
device_properties = torch.cuda.get_device_properties(device)
|
|
||||||
total_memory = device_properties.total_memory
|
|
||||||
allocated_memory = torch.cuda.memory_allocated(device)
|
|
||||||
reserved_memory = torch.cuda.memory_reserved(device)
|
|
||||||
available_memory = total_memory - max(allocated_memory, reserved_memory)
|
|
||||||
elif device.type == "mps":
|
|
||||||
logger.warning("MPS memory estimation is approximate. Using conservative defaults.")
|
|
||||||
return 2048, 256
|
|
||||||
else:
|
|
||||||
logger.warning(f"Unsupported device type {device.type} for optimal block calculation. Using defaults.")
|
|
||||||
return 32, 128
|
|
||||||
|
|
||||||
# Apply safety margin
|
|
||||||
available_memory = int(available_memory * safety_margin)
|
|
||||||
if available_memory <= 0:
|
|
||||||
logger.warning("Not enough available memory. Using minimum configuration.")
|
|
||||||
return 8, 128 # Minimum viable configuration
|
|
||||||
|
|
||||||
# Calculate memory per token
|
|
||||||
dtype_size = torch.tensor([], dtype=dtype).element_size()
|
dtype_size = torch.tensor([], dtype=dtype).element_size()
|
||||||
memory_per_token = 2 * num_kv_heads * head_dim * dtype_size * num_hidden_layers # For K and V caches
|
bytes_per_token = 2 * num_heads * head_dim * dtype_size * num_layers
|
||||||
|
if num_blocks is not None:
|
||||||
# Estimate sequence length requirements
|
# TODO
|
||||||
tokens_to_generate = getattr(generation_config, "max_new_tokens") or 20
|
max_possible_concurrent_requests = num_blocks * bytes_per_token
|
||||||
|
# FIXME: forgot to add the inintial prompt length in the mix....
|
||||||
if median_prefill_length is None and inputs:
|
max_possible_concurrent_requests = int(
|
||||||
non_empty_inputs = [len(seq) for seq in inputs if seq]
|
available_memory // (bytes_per_token * max_num_tokens * max_num_tokens // 4)
|
||||||
median_prefill_length = int(statistics.median(non_empty_inputs)) if non_empty_inputs else 64
|
|
||||||
elif median_prefill_length is None:
|
|
||||||
median_prefill_length = 64 # Reasonable default if no inputs provided
|
|
||||||
|
|
||||||
# Total sequence length including generated tokens
|
|
||||||
seq_length = median_prefill_length + tokens_to_generate
|
|
||||||
|
|
||||||
# Calculate block parameters
|
|
||||||
MIN_BLOCK_SIZE = 16
|
|
||||||
|
|
||||||
# Estimate number of concurrent sequences
|
|
||||||
per_sequence_memory = seq_length * memory_per_token
|
|
||||||
max_concurrent_sequences = max(1, int(available_memory // per_sequence_memory))
|
|
||||||
|
|
||||||
# Total tokens that can fit in memory
|
|
||||||
total_tokens = available_memory // memory_per_token
|
|
||||||
|
|
||||||
# Calculate block size (rounded to power of 2)
|
|
||||||
initial_block_size = max(MIN_BLOCK_SIZE, total_tokens // (max_concurrent_sequences * 2))
|
|
||||||
block_size = 1 << (initial_block_size - 1).bit_length() # Round to power of 2
|
|
||||||
|
|
||||||
# Calculate number of blocks
|
|
||||||
num_blocks = max(1, total_tokens // block_size)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Optimal cache: {num_blocks} blocks of size {block_size} "
|
|
||||||
f"(can handle ~{num_blocks * block_size // seq_length} sequences of length {seq_length})"
|
|
||||||
)
|
)
|
||||||
|
if max_possible_concurrent_requests <= 0:
|
||||||
return int(num_blocks), int(block_size)
|
logger.warning("you are trying to generate a bit too many tokens")
|
||||||
|
max_possible_concurrent_requests = 32
|
||||||
|
max_concurrent_tokens = min(64, max_possible_concurrent_requests)
|
||||||
|
# FIXME: Optimal means uses all memory
|
||||||
|
optimal_num_blocks = max(((max_concurrent_tokens * max_num_tokens) // block_size) + 1, 64)
|
||||||
|
return optimal_num_blocks, max_concurrent_tokens
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -775,11 +745,9 @@ class ContinuousBatchProcessor:
|
|||||||
|
|
||||||
self.requests_in_batch: list[RequestState] = []
|
self.requests_in_batch: list[RequestState] = []
|
||||||
|
|
||||||
# Get batch size parameters from generation config
|
|
||||||
self._configure_batch_parameters()
|
|
||||||
|
|
||||||
# Set up metrics collector
|
# Set up metrics collector
|
||||||
self.metrics = ContinuousBatchProcessorMetrics(self.max_batch_tokens)
|
self.max_batch_tokens = cache.max_batch_tokens
|
||||||
|
self.metrics = ContinuousBatchProcessorMetrics(cache.max_batch_tokens)
|
||||||
|
|
||||||
self.setup_static_tensors()
|
self.setup_static_tensors()
|
||||||
|
|
||||||
@@ -847,25 +815,6 @@ class ContinuousBatchProcessor:
|
|||||||
+ self.get_model_kwargs().__repr__()
|
+ self.get_model_kwargs().__repr__()
|
||||||
)
|
)
|
||||||
|
|
||||||
@traced(standalone=True)
|
|
||||||
def _configure_batch_parameters(self):
|
|
||||||
"""Set up batch processing parameters based on generation config."""
|
|
||||||
# Calculate total cache capacity
|
|
||||||
total_cache_tokens = self.cache.num_blocks * self.cache.block_size
|
|
||||||
|
|
||||||
# Get or calculate max tokens per batch
|
|
||||||
user_batch_tokens = getattr(self.generation_config, "max_batch_tokens", None)
|
|
||||||
if user_batch_tokens is not None:
|
|
||||||
self.max_batch_tokens = user_batch_tokens
|
|
||||||
else:
|
|
||||||
# Default to 1/8 of total cache capacity, adjusted for context
|
|
||||||
self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048)
|
|
||||||
recommended_batch_size = min(total_cache_tokens // 8, self.max_context_len)
|
|
||||||
self.max_batch_tokens = max(64, recommended_batch_size)
|
|
||||||
|
|
||||||
# Context length and EOS token
|
|
||||||
self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048)
|
|
||||||
|
|
||||||
@traced
|
@traced
|
||||||
def _get_new_requests(self):
|
def _get_new_requests(self):
|
||||||
"""Pull new requests from the input queue and add to waiting list."""
|
"""Pull new requests from the input queue and add to waiting list."""
|
||||||
@@ -1041,6 +990,8 @@ class ContinuousBatchProcessor:
|
|||||||
self._maybe_send_output(state, token)
|
self._maybe_send_output(state, token)
|
||||||
elif state.status == RequestStatus.PREFILLING_SPLIT:
|
elif state.status == RequestStatus.PREFILLING_SPLIT:
|
||||||
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
|
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
|
||||||
|
if self.cache.get_num_free_blocks() == 0:
|
||||||
|
raise ValueError("No more free blocks")
|
||||||
|
|
||||||
@traced
|
@traced
|
||||||
def has_pending_requests(self) -> bool:
|
def has_pending_requests(self) -> bool:
|
||||||
@@ -1062,7 +1013,9 @@ class ContinuousBatchProcessor:
|
|||||||
Args:
|
Args:
|
||||||
error: The error to report in the failure message
|
error: The error to report in the failure message
|
||||||
"""
|
"""
|
||||||
for state in self.scheduler.active_requests.values():
|
|
||||||
|
requests = list(self.scheduler.active_requests.values())
|
||||||
|
for state in requests:
|
||||||
self._handle_request_error(error, state)
|
self._handle_request_error(error, state)
|
||||||
self.scheduler.finish_request(state.request_id)
|
self.scheduler.finish_request(state.request_id)
|
||||||
|
|
||||||
@@ -1296,6 +1249,7 @@ class ContinuousBatchingManager:
|
|||||||
self.generation_config,
|
self.generation_config,
|
||||||
self.model.device,
|
self.model.device,
|
||||||
self.model.dtype,
|
self.model.dtype,
|
||||||
|
num_requests=len(self.input_queue.queue),
|
||||||
tp_size=getattr(self.model, "tp_size"),
|
tp_size=getattr(self.model, "tp_size"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1324,30 +1278,7 @@ class ContinuousBatchingManager:
|
|||||||
)
|
)
|
||||||
self.batch_processor = batch_processor
|
self.batch_processor = batch_processor
|
||||||
is_first = True
|
is_first = True
|
||||||
|
while (not self.stop_event.is_set()) or batch_processor.has_pending_requests():
|
||||||
if self.profile:
|
|
||||||
tracing_schedule = schedule(skip_first=2, warmup=1, active=1, repeat=3, wait=1)
|
|
||||||
trace_handler = tensorboard_trace_handler(
|
|
||||||
dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile"
|
|
||||||
)
|
|
||||||
activities = [
|
|
||||||
torch.profiler.ProfilerActivity.CPU,
|
|
||||||
torch.profiler.ProfilerActivity.CUDA,
|
|
||||||
]
|
|
||||||
with profile(
|
|
||||||
activities=activities,
|
|
||||||
schedule=tracing_schedule,
|
|
||||||
on_trace_ready=trace_handler,
|
|
||||||
record_shapes=False,
|
|
||||||
with_stack=True,
|
|
||||||
) as prof:
|
|
||||||
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
|
|
||||||
self._inner_generation_loop(batch_processor, is_first)
|
|
||||||
if is_first:
|
|
||||||
is_first = False
|
|
||||||
prof.step()
|
|
||||||
else:
|
|
||||||
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
|
|
||||||
self._inner_generation_loop(batch_processor, is_first)
|
self._inner_generation_loop(batch_processor, is_first)
|
||||||
if is_first:
|
if is_first:
|
||||||
is_first = False
|
is_first = False
|
||||||
@@ -1363,6 +1294,8 @@ class ContinuousBatchingManager:
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
batch_processor.prepare_next_batch()
|
batch_processor.prepare_next_batch()
|
||||||
|
device, total, reserved, allocated = get_device_and_memory()
|
||||||
|
logger.info(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}")
|
||||||
if torch.cuda.is_available() and self.use_cuda_graph:
|
if torch.cuda.is_available() and self.use_cuda_graph:
|
||||||
if is_first:
|
if is_first:
|
||||||
self.warmup(batch_processor)
|
self.warmup(batch_processor)
|
||||||
@@ -1502,6 +1435,7 @@ class ContinuousMixin:
|
|||||||
results[req_id] = result
|
results[req_id] = result
|
||||||
finished_count += 1
|
finished_count += 1
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
logger.debug(manager.batch_processor.tokenizer.decode(result.generated_tokens))
|
||||||
else:
|
else:
|
||||||
if not manager.is_running():
|
if not manager.is_running():
|
||||||
logger.error("Generation thread terminated unexpectedly.")
|
logger.error("Generation thread terminated unexpectedly.")
|
||||||
|
|||||||
@@ -4,8 +4,11 @@ from ..generation.continuous_batching import PagedAttentionCache
|
|||||||
from ..utils import is_flash_attn_2_available
|
from ..utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
try:
|
||||||
|
if is_flash_attn_2_available():
|
||||||
from flash_attn import flash_attn_varlen_func # noqa: F401
|
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def paged_attention_forward(
|
def paged_attention_forward(
|
||||||
|
|||||||
@@ -2705,9 +2705,9 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
kernel_function = partial(attention_wrapper, implementation=kernel)
|
kernel_function = partial(attention_wrapper, implementation=kernel)
|
||||||
elif kernel_name is not None:
|
elif kernel_name is not None:
|
||||||
kernel_function = getattr(kernel, kernel_name)
|
kernel_function = getattr(kernel, kernel_name)
|
||||||
ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function)
|
ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
|
||||||
ALL_MASK_ATTENTION_FUNCTIONS.register(
|
ALL_MASK_ATTENTION_FUNCTIONS.register(
|
||||||
applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
|
attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -2715,8 +2715,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
|
|||||||
"default attention implementation instead (sdpa if available, eager otherwise)."
|
"default attention implementation instead (sdpa if available, eager otherwise)."
|
||||||
)
|
)
|
||||||
|
|
||||||
applicable_attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
attn_implementation = "sdpa" # Try to fallback to sdpa in this case
|
||||||
return applicable_attn_implementation
|
return attn_implementation
|
||||||
else:
|
else:
|
||||||
return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
|
return self.get_correct_attn_implementation(applicable_attn_implementation, is_init_check)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user