Mini SGLang (Part 2) - Batching & Advanced Scheduling

In Part 1, we explored SGLang's architecture, engine initialization, and the lifecycle of a single request. Now we dive deeper into the sophisticated batching and scheduling mechanisms that make SGLang one of the fastest LLM serving systems available.

Continuous Batching: The Foundation of Throughput

Traditional static batching waits until a batch is full or a timeout expires before processing, and every request in that batch must finish before the next batch can start. Continuous batching removes this constraint by dynamically adding and removing requests at each iteration.

import asyncio
from asyncio import Queue
from torch import nn

class ContinuousBatchScheduler:
    def __init__(self, model: nn.Module, max_batch_size: int = 32):
        self.model = model
        self.max_batch_size = max_batch_size
        self.active_requests: list[ActiveRequest] = []
        self.waiting_queue: Queue[Request] = Queue()

    async def process_loop(self):
        """Main processing loop - one iteration at a time."""
        while True:
            self._maybe_add_requests()

            if not self.active_requests:
                await asyncio.sleep(0.01)
                continue

            self._step_generation()
            self._remove_completed_requests()
            await asyncio.sleep(0)  # yield so new requests can be enqueued
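The admission step is left abstract above. Here is a minimal sketch of how `_maybe_add_requests` might pull work from the waiting queue; the toy `Request` and `Scheduler` classes are stand-ins, and a real scheduler would also check KV-cache headroom before admitting:

```python
from collections import deque

class Request:
    def __init__(self, rid: int):
        self.rid = rid

class Scheduler:
    """Toy stand-in for the admission logic of ContinuousBatchScheduler."""
    def __init__(self, max_batch_size: int = 4):
        self.max_batch_size = max_batch_size
        self.active_requests: list[Request] = []
        self.waiting_queue: deque[Request] = deque()

    def _maybe_add_requests(self) -> int:
        # Admit waiting requests until the batch is full. (Assumption:
        # production schedulers also gate on free KV-cache memory.)
        admitted = 0
        while self.waiting_queue and len(self.active_requests) < self.max_batch_size:
            self.active_requests.append(self.waiting_queue.popleft())
            admitted += 1
        return admitted

sched = Scheduler(max_batch_size=4)
sched.waiting_queue.extend(Request(i) for i in range(6))
sched._maybe_add_requests()
print(len(sched.active_requests), len(sched.waiting_queue))  # 4 2
```

Because admission runs every iteration, a request that arrives mid-batch starts on the very next step instead of waiting for the whole batch to drain.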

Benefits of Continuous Batching:

  • No wasted compute on padding
  • New requests can start immediately
  • Better GPU utilization (typically 2-3x improvement)
  • More predictable latency

Prefill vs Decode: Understanding the Two Phases

LLM inference has two fundamentally different phases:

Prefill Phase:

  • Process the entire prompt at once
  • Compute-bound (lots of matrix multiplications)
  • High arithmetic intensity
  • Can process many tokens in parallel

Decode Phase:

  • Generate one token at a time
  • Memory-bound (loading KV cache dominates)
  • Low arithmetic intensity
  • Sequential by nature

def analyze_phase_characteristics(
    batch_size: int,
    seq_len: int,
    model_flops_per_token: float,
    model_weights_size: float,
    kv_cache_size_per_request: float,
) -> tuple[float, float]:
    # Prefill: high compute, parallelizable
    prefill_flops = batch_size * seq_len * model_flops_per_token
    prefill_memory = model_weights_size  # weights loaded once per forward pass

    # Decode: low compute, memory-dominated
    decode_flops = batch_size * 1 * model_flops_per_token
    decode_memory = batch_size * kv_cache_size_per_request

    # Prefill is typically 10-100x higher arithmetic intensity
    return prefill_flops / prefill_memory, decode_flops / decode_memory
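Plugging in rough numbers makes the gap concrete. The figures below are order-of-magnitude assumptions for a ~7B-parameter fp16 model, not measurements:

```python
# Rough numbers for a ~7B-parameter model (assumptions, order of magnitude only):
flops_per_token = 2 * 7e9        # ~2 FLOPs per parameter per token
weights_bytes = 14e9             # fp16 weights
kv_bytes_per_request = 2e9       # KV cache for one long active context

batch, seq_len = 1, 512
prefill_intensity = (batch * seq_len * flops_per_token) / weights_bytes
decode_intensity = (batch * flops_per_token) / (batch * kv_bytes_per_request)

print(prefill_intensity, decode_intensity)  # 512.0 7.0 -> roughly a 73x gap
```

With these assumptions a 512-token prefill does ~512 FLOPs per byte moved, while decode manages only ~7, which is why decode is memory-bound and benefits so much from batching.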

RadixAttention: Prefix Caching for Efficiency

RadixAttention is SGLang's innovative approach to prefix caching using a radix tree (trie) data structure. When multiple requests share a common prefix (like a system prompt), they can share the KV cache for those tokens.

from typing import Dict, List, Optional, Tuple

class RadixTreeNode:
    def __init__(self):
        self.children: Dict[int, "RadixTreeNode"] = {}
        self.kv_block_id: Optional[int] = None  # KV cache block for this token

class RadixCache:
    def __init__(self):
        self.root = RadixTreeNode()

    def match_prefix(self, token_ids: List[int]) -> Tuple[List[int], List[int]]:
        """Find the longest cached prefix matching token_ids."""
        current_node = self.root
        matched_tokens = []
        matched_blocks = []

        for token_id in token_ids:
            if token_id not in current_node.children:
                break
            current_node = current_node.children[token_id]
            matched_tokens.append(token_id)
            if current_node.kv_block_id is not None:
                matched_blocks.append(current_node.kv_block_id)

        return matched_tokens, matched_blocks
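To make the sharing concrete, here is a toy end-to-end use of the trie. The `insert` method and the one-block-per-token layout are simplifying assumptions; real SGLang stores token ranges per node and reference-counts shared blocks:

```python
class Node:
    def __init__(self):
        self.children: dict[int, "Node"] = {}
        self.kv_block_id: int | None = None

class Cache:
    def __init__(self):
        self.root = Node()
        self._next_block = 0  # toy allocator: one KV block per token

    def insert(self, token_ids: list[int]) -> None:
        """Record a finished request's tokens (and their KV blocks) in the trie."""
        node = self.root
        for t in token_ids:
            if t not in node.children:
                node.children[t] = Node()
                node.children[t].kv_block_id = self._next_block
                self._next_block += 1
            node = node.children[t]

    def match_prefix(self, token_ids: list[int]) -> int:
        """Return how many leading tokens are already cached."""
        node, matched = self.root, 0
        for t in token_ids:
            if t not in node.children:
                break
            node = node.children[t]
            matched += 1
        return matched

cache = Cache()
cache.insert([1, 2, 3, 4, 5])            # first request: system prompt + question
print(cache.match_prefix([1, 2, 3, 9]))  # second request shares a 3-token prefix -> 3
```

The second request skips prefill for the matched tokens and only computes attention for the 1 new token, which is where the large speedups on shared-prompt workloads come from.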

RadixAttention Benefits:

  • Up to 40x speedup for requests with cached prefixes
  • Efficient memory usage through block sharing
  • Automatic cache management with LRU eviction

Overlap Scheduling: Maximizing GPU Utilization

SGLang's overlap scheduling interleaves prefill and decode operations to keep the GPU continuously busy:

class OverlapScheduler:
    def schedule_iteration(self) -> ExecutionPlan:
        prefill_requests = self.get_pending_prefills()
        decode_requests = self.get_active_decodes()

        # Calculate available compute budget
        prefill_tokens = sum(len(r.prompt_tokens) - r.cached_tokens
                           for r in prefill_requests)
        decode_tokens = len(decode_requests)

        # Balance prefill and decode based on GPU utilization
        if self.is_compute_bound():
            # Prioritize prefill (better arithmetic intensity)
            return self.schedule_prefill_heavy(prefill_requests, decode_requests)
        else:
            # Prioritize decode (reduce latency for active requests)
            return self.schedule_decode_heavy(prefill_requests, decode_requests)
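One way to make the budget balancing concrete is a per-iteration token budget: decode tokens are admitted first (they bound latency for active requests), and whatever budget remains goes to prefill. This heuristic is an assumption for illustration, not SGLang's exact policy:

```python
def plan_iteration(num_decodes: int, pending_prefill_tokens: int,
                   token_budget: int = 8192) -> tuple[int, int]:
    """Return (decode_tokens, prefill_tokens) scheduled this iteration."""
    # Each active decode contributes exactly one token per step.
    decode_tokens = min(num_decodes, token_budget)
    remaining = token_budget - decode_tokens
    # Fill the rest of the budget with prefill work.
    prefill_tokens = min(pending_prefill_tokens, remaining)
    return decode_tokens, prefill_tokens

print(plan_iteration(num_decodes=100, pending_prefill_tokens=20000))  # (100, 8092)
```

With 100 active decodes and a large pending prompt, the iteration runs all 100 decode tokens plus an 8092-token slice of prefill, keeping the GPU compute-bound without starving in-flight requests.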

Chunked Prefill

Long prompts can block decode operations for hundreds of milliseconds. Chunked prefill breaks large prefills into smaller chunks:

def chunked_prefill(prompt_tokens: List[int], chunk_size: int = 512):
    """Process prompt in chunks to allow decode interleaving."""
    for i in range(0, len(prompt_tokens), chunk_size):
        chunk = prompt_tokens[i:i + chunk_size]
        yield chunk
        # After each chunk, allow decode operations to run
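For example, a 1300-token prompt with the default chunk size splits into three chunks, with decode steps allowed to run between them (the generator is repeated here so the snippet runs standalone):

```python
def chunked_prefill(prompt_tokens, chunk_size=512):
    """Yield the prompt in fixed-size chunks (last chunk may be shorter)."""
    for i in range(0, len(prompt_tokens), chunk_size):
        yield prompt_tokens[i:i + chunk_size]

chunks = list(chunked_prefill(list(range(1300))))
print([len(c) for c in chunks])  # [512, 512, 276]
```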

Tensor Parallelism: Scaling Across GPUs

For models too large to fit on a single GPU, tensor parallelism splits the computation across multiple GPUs:

import torch
from torch import nn

class TensorParallelLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int,
                 tp_size: int, tp_rank: int, parallel_mode: str = "column"):
        super().__init__()
        self.tp_size = tp_size
        self.tp_rank = tp_rank

        if parallel_mode == "column":
            # Each GPU holds a slice of the output dimension
            self.weight = nn.Parameter(
                torch.randn(in_features, out_features // tp_size)
            )
        else:  # row parallel
            # Each GPU holds a slice of the input dimension
            self.weight = nn.Parameter(
                torch.randn(in_features // tp_size, out_features)
            )

Communication Patterns:

  • Column parallel: outputs stay sharded; no communication needed when the next op is elementwise and the following layer is row parallel
  • Row parallel: one all-reduce to sum the partial outputs
  • Total: 2 all-reduce operations per transformer layer in the forward pass (one after attention, one after the MLP)
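The column-then-row pattern can be checked numerically: split a matrix column-wise across two simulated ranks, feed each shard's output into its matching row-parallel shard, and "all-reduce" (sum) the partial results. Plain Python lists stand in for GPU tensors here:

```python
def matmul(a, b):
    # (m x k) @ (k x n) with plain nested lists
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

x  = [[1.0, 2.0]]                                      # 1x2 activation
w1 = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]      # 2x4, column-parallel
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]  # 4x2, row-parallel

partials = []
for rank in range(2):
    cols = slice(2 * rank, 2 * rank + 2)
    w1_shard = [row[cols] for row in w1]   # this rank's output columns of w1
    w2_shard = w2[cols]                    # this rank's input rows of w2
    h = matmul(x, w1_shard)                # no communication needed here
    partials.append(matmul(h, w2_shard))   # per-rank partial sums

# The "all-reduce": element-wise sum of the per-rank partial outputs
y = [[sum(p[0][j] for p in partials) for j in range(2)]]
print(y)                                   # [[68.0, 31.0]]

# Reference: the unsharded computation gives the same answer
assert y == matmul(matmul(x, w1), w2)
```

Only the final sum requires cross-rank communication, which is exactly the single all-reduce per column/row pair counted above.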

Performance Optimizations

Flash Attention

FlashAttention reduces memory usage from O(N²) to O(N) through kernel fusion and tiling:

from torch import Tensor

def flash_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    """Memory-efficient attention using tiling and online softmax.

    - Process K/V in tiles that fit in SRAM
    - Accumulate the output with an online softmax, never materializing
      the full N x N attention matrix
    - Result: roughly 10-20x memory reduction and 2-4x speedup
    """
    ...
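The key trick is the online softmax: process scores tile by tile, tracking a running maximum and a rescaled running sum, so the result matches a full softmax without ever holding all scores at once. A scalar sketch of the idea (not the fused kernel, which also accumulates the weighted values in one pass):

```python
import math

def online_softmax(tiles):
    """Softmax over the concatenation of tiles, processed one tile at a time."""
    m, s = float("-inf"), 0.0
    for tile in tiles:
        new_m = max(m, max(tile))
        # Rescale the running sum when a new maximum appears
        s = s * math.exp(m - new_m) + sum(math.exp(x - new_m) for x in tile)
        m = new_m
    # Emit normalized values using the final max and sum
    return [math.exp(x - m) / s for tile in tiles for x in tile]

scores = [2.0, 1.0, 0.5, 3.0]
tiled = online_softmax([scores[:2], scores[2:]])
full = [math.exp(x) / sum(math.exp(y) for y in scores) for x in scores]
print(all(abs(a - b) < 1e-12 for a, b in zip(tiled, full)))  # True
```

Because each tile only needs the running `(m, s)` pair, peak memory scales with the tile size rather than the full sequence length.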

Paged Attention

Manage KV cache memory using virtual memory techniques:

  • Eliminate memory fragmentation
  • Dynamic allocation without reserving max sequence length
  • Memory sharing for prefix caching

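A toy block table shows the idea behind this virtual-memory approach: logical token positions map to (physical block, offset) through a per-request table, so a sequence grows one block at a time instead of reserving its maximum length up front. The block size and class names here are illustrative assumptions:

```python
BLOCK_SIZE = 16  # tokens per physical KV block (assumption)

class BlockTable:
    def __init__(self, free_list):
        self.free_list = free_list        # shared pool of physical block ids
        self.blocks: list[int] = []       # this request's blocks, logical order

    def slot_for(self, token_pos: int) -> tuple[int, int]:
        """Map a logical token position to (physical block, offset)."""
        block_idx, offset = divmod(token_pos, BLOCK_SIZE)
        while block_idx >= len(self.blocks):       # allocate blocks on demand
            self.blocks.append(self.free_list.pop(0))
        return self.blocks[block_idx], offset

free_blocks = list(range(100))   # global pool shared by all requests
req = BlockTable(free_blocks)
print(req.slot_for(0), req.slot_for(17), len(req.blocks))  # (0, 0) (1, 1) 2
```

Since blocks need not be contiguous, freed blocks from finished requests are immediately reusable, eliminating fragmentation, and two requests can point their tables at the same physical blocks for a shared prefix.
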
Speculative Decoding

Use a smaller draft model to propose multiple tokens, then verify with the target model:

def speculative_decode(draft_model, target_model, prompt, draft_length=5):
    # Step 1: Draft model generates K tokens
    draft_tokens = draft_model.generate(prompt, max_tokens=draft_length)

    # Step 2: Target model verifies all K tokens in ONE forward pass
    target_logits = target_model.forward(prompt + draft_tokens)

    # Step 3: Accept/reject each draft token
    # Average speedup: 2-4x with 80%+ acceptance rate
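Step 3 can be sketched with a toy greedy acceptance rule: walk the draft tokens, accept while each matches the target model's greedy choice at that position, and emit the target's token at the first mismatch. Real systems use rejection sampling over the two models' probability distributions rather than greedy matching:

```python
def accept_draft(draft_tokens, target_greedy_tokens):
    """Return the tokens actually emitted this round (greedy toy version)."""
    accepted = []
    for d, t in zip(draft_tokens, target_greedy_tokens):
        if d != t:
            accepted.append(t)   # replace the first wrong draft token, then stop
            break
        accepted.append(d)
    return accepted

# Target agrees with the first 3 of 5 draft tokens:
print(accept_draft([5, 9, 2, 7, 1], [5, 9, 2, 4, 8]))  # [5, 9, 2, 4]
```

Even on a mismatch the round makes progress (the target's correction is emitted), so one target forward pass always yields at least one token and often several.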

Conclusion

SGLang's advanced batching and scheduling techniques represent the state-of-the-art in LLM serving:

  • Continuous batching maximizes GPU utilization
  • RadixAttention eliminates redundant computation through prefix caching
  • Overlap scheduling balances prefill and decode for optimal throughput
  • Tensor parallelism enables efficient multi-GPU serving

Understanding these mechanisms is essential for building high-performance LLM applications. In Part 3, we'll explore production deployment considerations and advanced features like structured output generation.