Mini SGLang Part 2: Batching & Advanced Scheduling
In Part 1, we explored SGLang's architecture, engine initialization, and the lifecycle of a single request. Now we dive deeper into the sophisticated batching and scheduling mechanisms that make SGLang one of the fastest LLM serving systems available.
Continuous Batching: The Foundation of Throughput
Traditional static batching waits until a batch is full or a timeout expires before processing. Continuous batching revolutionizes this by dynamically adding and removing requests at each iteration.
```python
import asyncio
from torch import nn

class ContinuousBatchScheduler:
    def __init__(self, model: nn.Module, max_batch_size: int = 32):
        self.model = model
        self.max_batch_size = max_batch_size
        self.active_requests: list[ActiveRequest] = []
        self.waiting_queue: asyncio.Queue[Request] = asyncio.Queue()

    async def process_loop(self):
        """Main processing loop - one iteration at a time."""
        while True:
            # Admit waiting requests up to max_batch_size
            self._maybe_add_requests()
            if not self.active_requests:
                # Nothing to do; back off briefly
                await asyncio.sleep(0.01)
                continue
            # Run one forward pass for the whole active batch
            self._step_generation()
            # Retire requests that hit EOS or their max length
            self._remove_completed_requests()
            # Yield control so new requests can be enqueued
            await asyncio.sleep(0)
```
Benefits of Continuous Batching:
- No wasted compute on padding
- New requests can start immediately
- Better GPU utilization (typically 2-3x improvement)
- More predictable latency
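The admission step (`_maybe_add_requests` in the scheduler above) can be sketched as a standalone function. This is a minimal sketch with simplified stand-ins for the request and queue types, not SGLang's actual implementation:

```python
from collections import deque

def maybe_add_requests(active: list, waiting: deque, max_batch_size: int) -> int:
    """Admit waiting requests until the batch is full; return how many joined."""
    admitted = 0
    while waiting and len(active) < max_batch_size:
        # New requests join the running batch immediately - no waiting
        # for the current batch to drain
        active.append(waiting.popleft())
        admitted += 1
    return admitted
```

Because admission happens every iteration, a request that arrives mid-generation starts its prefill on the very next step instead of waiting for the whole batch to finish.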
Prefill vs Decode: Understanding the Two Phases
LLM inference has two fundamentally different phases:
Prefill Phase:
- Process the entire prompt at once
- Compute-bound (lots of matrix multiplications)
- High arithmetic intensity
- Can process many tokens in parallel
Decode Phase:
- Generate one token at a time
- Memory-bound (loading KV cache dominates)
- Low arithmetic intensity
- Sequential by nature
```python
def analyze_phase_characteristics(
    batch_size: int,
    seq_len: int,
    model_flops_per_token: float,
    model_weights_size: float,
    kv_cache_size_per_request: float,
) -> tuple[float, float]:
    """Compare the arithmetic intensity (FLOPs per byte) of both phases."""
    # Prefill: high compute, parallelizable across all prompt tokens
    prefill_flops = batch_size * seq_len * model_flops_per_token
    prefill_memory = model_weights_size  # one-time weight load per pass
    # Decode: low compute, memory-dominated by KV-cache reads
    decode_flops = batch_size * 1 * model_flops_per_token
    decode_memory = batch_size * kv_cache_size_per_request
    prefill_arithmetic_intensity = prefill_flops / prefill_memory
    decode_arithmetic_intensity = decode_flops / decode_memory
    # Prefill is typically 10-100x higher arithmetic intensity
    return prefill_arithmetic_intensity, decode_arithmetic_intensity
```
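Plugging in concrete numbers makes the gap tangible. The figures below are illustrative assumptions for a 7B-parameter fp16 model (roughly 2 FLOPs per parameter per token, ~14 GB of weights, ~0.5 MB of KV cache per token), not measurements:

```python
# Illustrative assumptions for a 7B-parameter fp16 model
flops_per_token = 2 * 7e9   # ~14 GFLOPs per token processed
weights_bytes = 14e9        # fp16 weights, read once per forward pass

# Prefill: one 512-token prompt processed in a single pass
prefill_intensity = (512 * flops_per_token) / weights_bytes

# Decode: batch of 8 requests, one token each, each carrying a
# 512-token KV cache at an assumed ~0.5 MB per token
kv_bytes = 8 * 512 * 0.5e6
decode_intensity = (8 * flops_per_token) / (weights_bytes + kv_bytes)

print(f"prefill: {prefill_intensity:.0f} FLOPs/byte")  # hundreds
print(f"decode:  {decode_intensity:.1f} FLOPs/byte")   # single digits
```

On a GPU whose roofline sits at tens to hundreds of FLOPs per byte, prefill saturates the compute units while decode leaves them mostly idle - which is exactly why the scheduler treats the two phases differently.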
RadixAttention: Prefix Caching for Efficiency
RadixAttention is SGLang's innovative approach to prefix caching using a radix tree (trie) data structure. When multiple requests share a common prefix (like a system prompt), they can share the KV cache for those tokens.
```python
from typing import List, Tuple

class RadixCache:
    def __init__(self):
        self.root = RadixTreeNode()

    def match_prefix(self, token_ids: List[int]) -> Tuple[List[int], List[int]]:
        """Find the longest cached prefix matching token_ids."""
        current_node = self.root
        matched_tokens = []
        matched_blocks = []
        for token_id in token_ids:
            if token_id in current_node.children:
                current_node = current_node.children[token_id]
                matched_tokens.append(token_id)
                if current_node.kv_block_id is not None:
                    matched_blocks.append(current_node.kv_block_id)
            else:
                break
        return matched_tokens, matched_blocks
```
RadixAttention Benefits:
- Up to 40x speedup for requests with cached prefixes
- Efficient memory usage through block sharing
- Automatic cache management with LRU eviction
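The insertion side of the cache, and the timestamps that drive LRU eviction, can be sketched as follows. This is a self-contained toy version - the `RadixTreeNode` fields and the one-token-per-edge layout are simplifying assumptions, and the eviction walk itself is omitted:

```python
import time

class RadixTreeNode:
    def __init__(self):
        self.children: dict = {}
        self.kv_block_id = None
        self.last_access = 0.0  # refreshed on every hit, consulted by LRU eviction

def radix_insert(root: RadixTreeNode, token_ids: list, block_ids: list) -> None:
    """Record a token sequence, one KV block id per token."""
    node = root
    for token_id, block_id in zip(token_ids, block_ids):
        node = node.children.setdefault(token_id, RadixTreeNode())
        node.kv_block_id = block_id
        node.last_access = time.monotonic()

def radix_match(root: RadixTreeNode, token_ids: list) -> list:
    """Return KV block ids for the longest cached prefix, refreshing LRU stamps."""
    node, blocks = root, []
    for token_id in token_ids:
        if token_id not in node.children:
            break
        node = node.children[token_id]
        node.last_access = time.monotonic()  # keep hot prefixes from being evicted
        if node.kv_block_id is not None:
            blocks.append(node.kv_block_id)
    return blocks
```

A second request sharing the first 3 tokens of an inserted sequence would get those 3 KV blocks back from `radix_match` and skip recomputing them entirely.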
Overlap Scheduling: Maximizing GPU Utilization
SGLang's overlap scheduling interleaves prefill and decode operations to keep the GPU continuously busy:
```python
class OverlapScheduler:
    def schedule_iteration(self) -> ExecutionPlan:
        prefill_requests = self.get_pending_prefills()
        decode_requests = self.get_active_decodes()

        # Calculate available compute budget
        prefill_tokens = sum(len(r.prompt_tokens) - r.cached_tokens
                             for r in prefill_requests)
        decode_tokens = len(decode_requests)  # one new token per request

        # Balance prefill and decode based on GPU utilization
        if self.is_compute_bound():
            # Prioritize prefill (better arithmetic intensity)
            return self.schedule_prefill_heavy(prefill_requests, decode_requests)
        else:
            # Prioritize decode (reduce latency for active requests)
            return self.schedule_decode_heavy(prefill_requests, decode_requests)
```
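One concrete way to balance the two workloads is a per-iteration token budget: decode tokens are admitted first because they are latency-sensitive, and leftover budget is filled with prefill work. This is a hedged sketch of one possible policy - the budget value and FIFO admission are assumptions, not SGLang's exact heuristic:

```python
def plan_iteration(prefill_token_counts: list,
                   num_decode_requests: int,
                   token_budget: int = 4096) -> tuple:
    """Split one iteration's token budget between decode and prefill.

    Decode requests go first (one token each); remaining budget is
    filled with whole prefill requests in FIFO order.
    Returns (decode_tokens, prefill_requests_admitted).
    """
    decode_tokens = min(num_decode_requests, token_budget)
    remaining = token_budget - decode_tokens
    admitted = 0
    for count in prefill_token_counts:
        if count > remaining:
            break  # this prefill would overflow the step; defer it
        remaining -= count
        admitted += 1
    return decode_tokens, admitted
```

With a 4096-token budget, 100 active decodes, and prefills of 1000 / 3000 / 500 tokens queued, only the first prefill fits alongside the decodes; the 3000-token prompt waits one iteration rather than stalling every active request.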
Chunked Prefill
Long prompts can block decode operations for hundreds of milliseconds. Chunked prefill breaks large prefills into smaller chunks:
```python
from typing import List

def chunked_prefill(prompt_tokens: List[int], chunk_size: int = 512):
    """Process the prompt in chunks to allow decode interleaving."""
    for i in range(0, len(prompt_tokens), chunk_size):
        chunk = prompt_tokens[i:i + chunk_size]
        yield chunk
        # After each chunk, the scheduler can run pending decode steps
```
Tensor Parallelism: Scaling Across GPUs
For models too large to fit on a single GPU, tensor parallelism splits the computation across multiple GPUs:
```python
import torch
from torch import nn

class TensorParallelLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int,
                 tp_size: int, tp_rank: int, parallel_mode: str = "column"):
        super().__init__()
        self.tp_size = tp_size
        self.tp_rank = tp_rank
        if parallel_mode == "column":
            # Each GPU holds a slice of the output dimension
            self.weight = nn.Parameter(
                torch.randn(in_features, out_features // tp_size)
            )
        else:  # row parallel
            # Each GPU holds a slice of the input dimension
            self.weight = nn.Parameter(
                torch.randn(in_features // tp_size, out_features)
            )
```
Communication Patterns:
- Column parallel: No communication needed until reduction
- Row parallel: All-reduce after computation
- Total: 2 all-reduce operations per transformer layer
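Why a column-parallel layer followed by a row-parallel layer needs exactly one all-reduce can be checked numerically. The sketch below simulates `tp_size=2` "ranks" inside a single process with plain-Python matrices (the helper names and the in-process all-reduce are illustrative stand-ins for real multi-GPU collectives):

```python
def matmul(a, b):
    """Plain-Python matrix multiply for small illustrative matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def tensor_parallel_mlp(x, w1, w2, tp_size=2):
    """Column-parallel w1 followed by row-parallel w2, simulated in-process.

    Each 'rank' r keeps a column slice of w1 and the matching row slice
    of w2; the only communication is one all-reduce (here: an elementwise
    sum over the per-rank partial outputs) after the second matmul.
    """
    cols = len(w1[0]) // tp_size
    partials = []
    for r in range(tp_size):
        w1_shard = [row[r * cols:(r + 1) * cols] for row in w1]  # column slice
        w2_shard = w2[r * cols:(r + 1) * cols]                   # row slice
        y_shard = matmul(x, w1_shard)        # local: no communication needed
        partials.append(matmul(y_shard, w2_shard))
    # "All-reduce": sum the partial outputs contributed by every rank
    return [[sum(p[i][j] for p in partials) for j in range(len(partials[0][0]))]
            for i in range(len(partials[0]))]
```

Because the column slices of `w1` line up with the row slices of `w2`, each rank's partial product is an additive share of the full `x @ w1 @ w2`, and summing them reproduces the unsharded result exactly - one all-reduce per MLP, mirroring the two per transformer layer (attention + MLP) noted above.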
Performance Optimizations
FlashAttention
FlashAttention reduces memory usage from O(N²) to O(N) through kernel fusion and tiling:
```python
def flash_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """Memory-efficient attention using tiling and online softmax.

    Sketch of the algorithm:
    - Process Q/K/V in tiles that fit in SRAM
    - Accumulate attention incrementally with an online softmax,
      never materializing the full N x N score matrix
    - Result: 10-20x memory reduction, 2-4x speedup
    """
```
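The "online softmax" at the heart of FlashAttention can be demonstrated on a single query row. The sketch below processes scores tile by tile, keeping only a running max, normalizer, and weighted sum - a scalar-valued toy, not the fused CUDA kernel:

```python
import math

def online_softmax_weighted_sum(scores, values, tile_size=2):
    """softmax(scores) . values, computed one tile at a time.

    Maintains a running max (m), running normalizer (l), and running
    weighted sum (acc), rescaling earlier partial results whenever a
    new tile raises the max - the same per-row trick FlashAttention
    uses to avoid materializing the full score matrix.
    """
    m = float("-inf")  # running max of scores seen so far
    l = 0.0            # running sum of exp(score - m)
    acc = 0.0          # running weighted sum of values
    for start in range(0, len(scores), tile_size):
        tile_s = scores[start:start + tile_size]
        tile_v = values[start:start + tile_size]
        new_m = max(m, max(tile_s))
        scale = math.exp(m - new_m)  # rescale earlier partials to the new max
        l = l * scale + sum(math.exp(s - new_m) for s in tile_s)
        acc = acc * scale + sum(math.exp(s - new_m) * v
                                for s, v in zip(tile_s, tile_v))
        m = new_m
    return acc / l
```

The result matches the two-pass softmax exactly (up to floating-point error), but each tile is touched only once - which is what lets the real kernel keep everything in SRAM.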
Paged Attention
Manage KV cache memory using virtual memory techniques:
- Eliminate memory fragmentation
- Dynamic allocation without reserving max sequence length
- Memory sharing for prefix caching
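The core of a paged KV cache is a fixed-size block pool with reference counting. The sketch below shows the idea in miniature (the class and method names are illustrative, not vLLM's or SGLang's API):

```python
class BlockAllocator:
    """Fixed-size KV-cache block pool with reference counting.

    Blocks are handed out on demand - no per-request reservation of the
    maximum sequence length - and a block shared between requests (e.g.
    a cached prefix) returns to the pool only when its refcount hits zero.
    """
    def __init__(self, num_blocks: int):
        self.free_blocks = list(range(num_blocks))
        self.refcount = {}

    def allocate(self) -> int:
        """Hand out a fresh block for newly generated tokens."""
        block_id = self.free_blocks.pop()
        self.refcount[block_id] = 1
        return block_id

    def share(self, block_id: int) -> int:
        """Add a reference, e.g. when a request reuses a cached prefix."""
        self.refcount[block_id] += 1
        return block_id

    def free(self, block_id: int) -> None:
        """Drop one reference; recycle the block once nobody holds it."""
        self.refcount[block_id] -= 1
        if self.refcount[block_id] == 0:
            del self.refcount[block_id]
            self.free_blocks.append(block_id)
```

Because allocation is per-block rather than per-max-sequence, a 100-token response on a system provisioned for 4096-token sequences wastes no memory - and the refcounts are what make RadixAttention's prefix sharing safe.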
Speculative Decoding
Use a smaller draft model to propose multiple tokens, then verify with the target model:
```python
def speculative_decode(draft_model, target_model, prompt, draft_length=5):
    # Step 1: draft model generates K tokens autoregressively
    draft_tokens = draft_model.generate(prompt, max_tokens=draft_length)
    # Step 2: target model verifies all K tokens in ONE forward pass
    target_logits = target_model.forward(prompt + draft_tokens)
    # Step 3: accept/reject each draft token against target_logits
    # Average speedup: 2-4x with an 80%+ acceptance rate
    ...
```
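The accept/reject step can be made concrete in its simplest (greedy) form: accept draft tokens left to right while they match the target model's argmax, then substitute the target's own token at the first mismatch. This is a simplification - full speculative sampling uses rejection sampling over the two distributions - and the function below works on precomputed argmax tokens rather than live models:

```python
def verify_draft(draft_tokens, target_argmax_tokens):
    """Greedy verification of speculative drafts.

    target_argmax_tokens[i] is the target model's argmax after consuming
    the prompt plus draft_tokens[:i]; it has one extra entry at the end
    (the target's prediction after all K drafts). Every verify step
    therefore emits at least one token, so progress is guaranteed.
    """
    accepted = []
    for draft, target in zip(draft_tokens, target_argmax_tokens):
        if draft == target:
            accepted.append(draft)
        else:
            accepted.append(target)  # correction token from the target model
            break
    else:
        # All drafts accepted: take the bonus token from the final position
        accepted.append(target_argmax_tokens[len(draft_tokens)])
    return accepted
```

If the target agrees with 3 of 5 draft tokens, one verification pass emits 4 tokens (3 accepted + 1 correction) for the cost of a single target forward pass - the source of the 2-4x speedup cited above.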
Conclusion
SGLang's advanced batching and scheduling techniques represent the state-of-the-art in LLM serving:
- Continuous batching maximizes GPU utilization
- RadixAttention eliminates redundant computation through prefix caching
- Overlap scheduling balances prefill and decode for optimal throughput
- Tensor parallelism enables efficient multi-GPU serving
Understanding these mechanisms is essential for building high-performance LLM applications. In Part 3, we'll explore production deployment considerations and advanced features like structured output generation.