Mini SGLang (Part 1) - Architecture, Engine & Request Flow

·4 min read·llm·sglang·architecture

Mini SGLang Part 1: Architecture, Engine & Request Flow

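SGLang (Structured Generation Language) is one of the fastest open-source LLM inference engines, and most of its throughput comes from how it manages KV cache memory and schedules requests. In this deep dive, we'll build an understanding of those pieces from the ground up, starting with the architecture, engine initialization, and the path a request takes through the system.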

System Architecture Overview

SGLang is built around four layered components, with each request flowing from top to bottom:

┌─────────────────────────────────────────────────────────────┐
│                        API Server                           │
│  (FastAPI/HTTP endpoints, request validation, streaming)    │
└─────────────────────────┬───────────────────────────────────┘

┌─────────────────────────▼───────────────────────────────────┐
│                       Scheduler                             │
│  (Request queuing, priority, continuous batching)           │
└─────────────────────────┬───────────────────────────────────┘

┌─────────────────────────▼───────────────────────────────────┐
│                    Model Runner                             │
│  (Forward pass execution, KV cache management)              │
└─────────────────────────┬───────────────────────────────────┘

┌─────────────────────────▼───────────────────────────────────┐
│                  Memory Manager                             │
│  (Paged KV cache, RadixCache for prefix sharing)            │
└─────────────────────────────────────────────────────────────┘

Engine Initialization

When SGLang starts, it runs three initialization steps: loading the model, allocating the KV cache, and compiling CUDA kernels.

Step 1: Model Loading

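import torch
from torch import nn
from transformers import AutoModelForCausalLM

def load_model(model_path: str, config: ModelConfig) -> nn.Module:
    """Load model weights from disk into GPU memory."""
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,  # fp16 halves weight memory relative to fp32
        device_map="auto",          # requires the accelerate package
    )

    # Apply optional optimizations. patch_for_flash_attention is a helper
    # assumed to be defined elsewhere; it is not a transformers API.
    if config.use_flash_attention:
        model = patch_for_flash_attention(model)

    # eval() disables dropout and other training-only behavior
    return model.eval()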

Step 2: KV Cache Initialization

The KV cache is the largest memory consumer after model weights:

class PagedKVCache:
    def __init__(self, num_layers, num_kv_heads, head_dim,
                 block_size, max_num_blocks, dtype, device):
        self.num_layers = num_layers
        self.block_size = block_size
        self.max_num_blocks = max_num_blocks
 
        # Allocate physical page pool
        # Shape: [num_layers, 2, max_num_blocks, block_size, num_kv_heads, head_dim]
        self.kv_cache = torch.zeros(
            num_layers, 2, max_num_blocks, block_size,
            num_kv_heads, head_dim,
            dtype=dtype, device=device
        )
 
        self.free_blocks = set(range(max_num_blocks))
        self.request_to_blocks = {}

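For a 7B model with 32 layers, 32 KV heads, and a head dimension of 128, stored in fp16 (2 bytes per element):

KV cache size = 2 (K and V) × 32 layers × batch × seq_len × 32 KV heads × 128 head_dim × 2 bytes
             = 512 KiB per token × batch × seq_len
             = 32 GiB for batch=32, seq_len=2048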
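
The arithmetic above can be checked with a few lines of Python. This is a minimal sketch: the function name `kv_cache_bytes`, the 8 GiB budget, and the block size of 16 used at the end are illustrative assumptions, not values taken from SGLang.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, num_tokens, bytes_per_elem=2):
    """Bytes needed to hold K and V for `num_tokens` tokens across all layers."""
    # The leading 2 accounts for storing both a key and a value vector
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem * num_tokens


# Figures from the text: 32 layers, 32 KV heads, head_dim 128, fp16 (2 bytes)
per_token = kv_cache_bytes(32, 32, 128, num_tokens=1)
total = kv_cache_bytes(32, 32, 128, num_tokens=32 * 2048)

print(per_token // 1024, "KiB per token")
print(total // 2**30, "GiB for batch=32, seq_len=2048")

# Going the other way: how many blocks fit in a fixed memory budget?
# (hypothetical budget and block size, chosen only for illustration)
budget_bytes = 8 * 2**30
block_size = 16
max_num_blocks = budget_bytes // (per_token * block_size)
print(max_num_blocks, "blocks of", block_size, "tokens")
```

The second calculation is usually the more useful direction: given the memory left over after the weights are loaded, it yields the `max_num_blocks` value that a paged cache would be constructed with.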

Step 3: CUDA Kernel Compilation

SGLang uses custom CUDA kernels for critical operations:

def initialize_cuda_kernels(model_config, config):
    paged_attention_kernel = compile_paged_attention_kernel(
        head_size=model_config.hidden_size // model_config.num_attention_heads,
        num_kv_heads=model_config.num_kv_heads,
        block_size=config.block_size
    )
 
    sampling_kernel = compile_sampling_kernel(
        vocab_size=model_config.vocab_size,
        max_batch_size=config.max_batch_size
    )
 
    return {'paged_attention': paged_attention_kernel, 'sampling': sampling_kernel}

Request Lifecycle

Phase 1: Request Arrival

from fastapi import HTTPException

@app.post("/v1/generate")
async def generate(request: GenerateRequest):
    # Validate request. HTTPException maps to a 400 response; an uncaught
    # ValueError would reach the client as a 500 Internal Server Error.
    if not request.prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")

    # Tokenize prompt
    input_ids = tokenizer.encode(request.prompt, add_special_tokens=True)

    # Create internal request object
    internal_request = InternalRequest(
        request_id=generate_request_id(),
        input_ids=input_ids,
        sampling_params=SamplingParams(
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_tokens
        )
    )

    # Enqueue request for the scheduler
    await scheduler.add_request(internal_request)

    # Suspend this handler until the scheduler marks the request finished
    return await internal_request.wait_for_completion()
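
The handler above awaits `wait_for_completion()` while the scheduler does its work elsewhere. One way that hand-off could be wired up is with an `asyncio.Event`. The sketch below is a toy under stated assumptions: `finish`, `fake_scheduler`, and `handle` are hypothetical names introduced for illustration, and the "scheduler" just appends scripted tokens.

```python
import asyncio
from typing import List


class InternalRequest:
    """Toy stand-in for the request object; only the completion signal is modeled."""

    def __init__(self, request_id: str, input_ids: List[int]):
        self.request_id = request_id
        self.input_ids = input_ids
        self.output_tokens: List[int] = []
        self._done = asyncio.Event()  # set once generation has finished

    def finish(self) -> None:
        # Hypothetical hook the scheduler would call after the last token
        self._done.set()

    async def wait_for_completion(self) -> List[int]:
        # Suspends the caller without blocking the event loop
        await self._done.wait()
        return self.output_tokens


async def fake_scheduler(request: InternalRequest, tokens: List[int]) -> None:
    """Pretend to generate `tokens` one step at a time, then signal completion."""
    for token in tokens:
        await asyncio.sleep(0)  # yield control to the event loop between steps
        request.output_tokens.append(token)
    request.finish()


async def handle(prompt_ids: List[int]) -> List[int]:
    request = InternalRequest("req-0", prompt_ids)
    worker = asyncio.create_task(fake_scheduler(request, [7, 8, 9]))
    result = await request.wait_for_completion()
    await worker
    return result


result = asyncio.run(handle([1, 2, 3]))
print(result)
```

Because the handler only awaits an event, many requests can be in flight on one event loop while a single scheduler loop advances all of them.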

Phase 2: Request Admission

The scheduler admits waiting requests in FIFO order and stops at the first request whose prompt does not fit in the free KV cache blocks:

def _admit_new_requests(self):
    while self.waiting_queue:
        request = self.waiting_queue[0]
 
        num_tokens = len(request.input_ids)
        required_blocks = math.ceil(num_tokens / self.block_size)
 
        if len(self.cache_manager.get_free_blocks()) < required_blocks:
            break  # Not enough memory
 
        request = self.waiting_queue.popleft()
        cache_info = self.cache_manager.allocate_for_request(
            request.request_id, request.input_ids
        )
 
        request.kv_cache_info = cache_info
        self.running_requests.append(request)
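
To see the admission rule in action, here is a small self-contained simulation. It is a sketch, not the scheduler itself: requests are reduced to hypothetical `(request_id, num_tokens)` tuples and the cache to a simple counter of free blocks.

```python
import math
from collections import deque


def admit(waiting, free_blocks, block_size):
    """Admit requests FIFO until the head of the queue no longer fits."""
    admitted = []
    while waiting:
        request_id, num_tokens = waiting[0]  # peek without removing
        required = math.ceil(num_tokens / block_size)
        if required > free_blocks:
            break  # head of the queue does not fit; stop admitting
        waiting.popleft()
        free_blocks -= required
        admitted.append(request_id)
    return admitted, free_blocks


waiting = deque([("a", 40), ("b", 100), ("c", 10)])
admitted, remaining = admit(waiting, free_blocks=8, block_size=16)

print(admitted, remaining, list(waiting))
```

Request "a" needs 3 blocks and is admitted, leaving 5. Request "b" needs 7 and does not fit, so the loop stops there. Request "c" would fit in a single block but stays queued behind "b": strict FIFO admission trades some utilization for fairness and simplicity.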

Phase 3: Prefill Phase

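Prefill runs every prompt token that is not already cached through the model in a single forward pass, populating the KV cache and producing the first output token: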

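def _execute_prefill(self, request):
    # Skip tokens whose KV entries already exist (e.g. from a prefix cache hit)
    prefill_token_ids = request.input_ids[request.num_computed_tokens:]

    prefill_batch = ExecutionBatch(
        request_ids=[request.request_id],
        input_token_ids=torch.tensor([prefill_token_ids]),
        block_tables=[request.kv_cache_info['total_blocks']],
        is_prefill=True
    )

    logits = self.engine.forward(prefill_batch)
    # Only the last position's logits are needed to predict the first output token
    next_token = self._sample_token(logits[:, -1, :], request.sampling_params)

    # Record that the whole prompt is now in the KV cache
    request.num_computed_tokens = len(request.input_ids)
    request.output_tokens.append(next_token.item())
    request.status = RequestStatus.GENERATING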

Phase 4: Decode Phase

Decode generates one token per running request per step, feeding each request's most recent token back into the model:

def _execute_decode(self, running_requests):
    # Each request contributes only its most recent token; earlier tokens
    # are already represented in the KV cache
    input_token_ids = torch.tensor([
        [req.output_tokens[-1]] for req in running_requests
    ])

    decode_batch = ExecutionBatch(
        request_ids=[req.request_id for req in running_requests],
        input_token_ids=input_token_ids,
        block_tables=[req.kv_cache_info['total_blocks'] for req in running_requests],
        is_prefill=False
    )

    logits = self.engine.forward(decode_batch)
    # Take the last position, as in prefill, so the sampler receives
    # [batch, vocab] logits (assumes forward returns [batch, seq, vocab])
    next_tokens = self._sample_tokens(logits[:, -1, :], running_requests)

    for request, next_token in zip(running_requests, next_tokens):
        request.output_tokens.append(next_token.item())
        if self._should_stop(request):
            request.status = RequestStatus.FINISHED
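
The decode step relies on a `_should_stop` check that is not shown. A plausible version stops on an end-of-sequence token or when `max_tokens` is reached. The sketch below assumes exactly that and replaces the model with scripted tokens; `ToyRequest`, `EOS_TOKEN_ID`, and `run_decode` are hypothetical names used only here.

```python
from dataclasses import dataclass, field
from typing import List

EOS_TOKEN_ID = 0  # hypothetical end-of-sequence ID for this example


@dataclass
class ToyRequest:
    max_tokens: int
    script: List[int]  # tokens the fake model will emit, in order
    output_tokens: List[int] = field(default_factory=list)


def should_stop(request: ToyRequest) -> bool:
    """Stop on end-of-sequence or when the token budget is used up."""
    if request.output_tokens and request.output_tokens[-1] == EOS_TOKEN_ID:
        return True
    return len(request.output_tokens) >= request.max_tokens


def run_decode(requests: List[ToyRequest]) -> int:
    """Decode until every request has finished; returns the number of steps."""
    running = list(requests)
    steps = 0
    while running:
        steps += 1
        for req in running:
            # Stand-in for forward pass + sampling: read the next scripted token
            req.output_tokens.append(req.script[len(req.output_tokens)])
        # Finished requests leave the batch; the others continue
        running = [req for req in running if not should_stop(req)]
    return steps


a = ToyRequest(max_tokens=5, script=[5, 6, EOS_TOKEN_ID, 9, 9])
b = ToyRequest(max_tokens=2, script=[1, 2, 3])
steps = run_decode([a, b])

print(steps, a.output_tokens, b.output_tokens)
```

Request `b` hits its token budget after two steps and drops out, while `a` continues for one more step until it emits the end-of-sequence token. Removing finished requests between steps is what keeps the batch from doing wasted work.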

Memory Management

Paged KV Cache

Instead of allocating one contiguous region per request, SGLang divides the KV cache into fixed-size pages and assigns them to requests on demand:

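import math
from typing import List

class CacheManager:
    def allocate_for_request(self, request_id: str, input_ids: List[int]):
        # Check for cached prefix using RadixCache
        matched_prefix, matched_blocks = self.radix_cache.match_prefix(input_ids)

        # Only tokens backed by a shared block can be reused. This assumes every
        # matched block is full, so a match that ends mid-block is rounded down.
        block_size = self.kv_cache.block_size
        num_cached_tokens = min(len(matched_prefix), len(matched_blocks) * block_size)
        num_new_tokens = len(input_ids) - num_cached_tokens
        num_new_blocks = math.ceil(num_new_tokens / block_size)

        # Reused blocks come first, followed by freshly allocated ones
        new_blocks = self._allocate_blocks(num_new_blocks)
        all_blocks = matched_blocks + new_blocks

        self.kv_cache.request_to_blocks[request_id] = all_blocks

        return {
            'num_cached_tokens': num_cached_tokens,
            'total_blocks': all_blocks
        }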
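
The `_allocate_blocks` helper used above is not shown. A minimal free-list allocator along the following lines would be enough to back it. This is a sketch under assumptions: the class name `BlockAllocator` and its methods are hypothetical, and it tracks only block IDs, not the tensors they index.

```python
from typing import List


class BlockAllocator:
    """Hands out fixed-size KV cache block IDs from a free list."""

    def __init__(self, max_num_blocks: int):
        # Stored in reverse so that pop() returns the lowest free ID first
        self.free_blocks: List[int] = list(range(max_num_blocks - 1, -1, -1))

    def allocate(self, num_blocks: int) -> List[int]:
        if num_blocks > len(self.free_blocks):
            raise MemoryError(
                f"need {num_blocks} blocks, only {len(self.free_blocks)} free"
            )
        return [self.free_blocks.pop() for _ in range(num_blocks)]

    def free(self, block_ids: List[int]) -> None:
        # Blocks are interchangeable, so freed IDs can be reused in any order
        self.free_blocks.extend(block_ids)


allocator = BlockAllocator(max_num_blocks=4)
first = allocator.allocate(3)        # takes blocks 0, 1, 2

try:
    allocator.allocate(2)            # only one block is left
    out_of_memory = False
except MemoryError:
    out_of_memory = True

allocator.free(first)                # request finished; return its blocks
second = allocator.allocate(2)
```

Because any free block can serve any request, a request's blocks do not need to be adjacent, which is what avoids fragmentation. Note that blocks shared through a prefix cache would also need reference counts before they could be freed safely; that bookkeeping is left out here.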

RadixCache for Prefix Sharing

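Requests that start with the same tokens (for example, a shared system prompt) can reuse the same KV cache pages instead of recomputing them. RadixCache finds the longest cached prefix by walking a tree keyed on token IDs: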

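class RadixCache:
    def match_prefix(self, token_ids: List[int]):
        current_node = self.root
        matched_tokens = []
        matched_blocks = []
        pending_tokens = []  # matched tokens not yet covered by a cached block

        for token_id in token_ids:
            if token_id not in current_node.children:
                break
            current_node = current_node.children[token_id]
            pending_tokens.append(token_id)
            # Assumes the node for the last token of a block stores that block's ID.
            # Committing only at those nodes keeps tokens and blocks in step, so a
            # match that ends mid-block never reports tokens without a backing block.
            if current_node.kv_block_id is not None:
                matched_tokens.extend(pending_tokens)
                pending_tokens.clear()
                matched_blocks.append(current_node.kv_block_id)

        return matched_tokens, matched_blocks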
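
The lookup above only shows the read path. To make the idea concrete end to end, here is a toy prefix cache with both insertion and matching. It is a sketch under assumptions: it uses one tree node per token rather than a compressed radix tree, the names `Node`, `ToyPrefixCache`, and `insert` are hypothetical, and a block ID is attached to the node of the last token in each full block.

```python
from typing import Dict, List, Optional, Tuple


class Node:
    def __init__(self) -> None:
        self.children: Dict[int, "Node"] = {}
        self.kv_block_id: Optional[int] = None


class ToyPrefixCache:
    def __init__(self, block_size: int):
        self.root = Node()
        self.block_size = block_size

    def insert(self, token_ids: List[int], block_ids: List[int]) -> None:
        """Record which block holds each full group of `block_size` tokens."""
        node = self.root
        for i, token_id in enumerate(token_ids):
            node = node.children.setdefault(token_id, Node())
            is_block_end = (i + 1) % self.block_size == 0
            if is_block_end and node.kv_block_id is None:
                node.kv_block_id = block_ids[i // self.block_size]

    def match_prefix(self, token_ids: List[int]) -> Tuple[List[int], List[int]]:
        node = self.root
        matched_tokens: List[int] = []
        matched_blocks: List[int] = []
        pending: List[int] = []  # matched tokens not yet covered by a full block
        for token_id in token_ids:
            if token_id not in node.children:
                break
            node = node.children[token_id]
            pending.append(token_id)
            if node.kv_block_id is not None:
                matched_tokens += pending
                pending = []
                matched_blocks.append(node.kv_block_id)
        return matched_tokens, matched_blocks


cache = ToyPrefixCache(block_size=2)
# A finished request cached tokens [1, 2, 3, 4, 5]; blocks 10 and 11 hold the
# first two full blocks, and the trailing token 5 has no full block yet.
cache.insert([1, 2, 3, 4, 5], block_ids=[10, 11])

full_hit = cache.match_prefix([1, 2, 3, 4, 5, 6])
partial_hit = cache.match_prefix([1, 2, 3, 9])
miss = cache.match_prefix([7, 8])
```

In the first lookup, token 5 matches in the tree but is not reported because no full block backs it, so the new request reuses blocks 10 and 11 and recomputes from token 5 onward.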

Token Generation: Sampling

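def sample_tokens(logits: torch.Tensor, sampling_params: List[SamplingParams]):
    batch_size, vocab_size = logits.shape
    next_tokens = torch.zeros(batch_size, dtype=torch.long, device=logits.device)

    for i in range(batch_size):
        params = sampling_params[i]
        batch_logits = logits[i].clone()  # avoid mutating the caller's logits

        # Temperature 0 means greedy decoding; dividing by it would give inf/nan
        if params.temperature == 0:
            next_tokens[i] = torch.argmax(batch_logits)
            continue

        # Temperature scaling
        if params.temperature != 1.0:
            batch_logits = batch_logits / params.temperature

        # Top-k filtering: keep the k largest logits, mask the rest
        if params.top_k > 0:
            k = min(params.top_k, vocab_size)
            top_k_logits, top_k_indices = torch.topk(batch_logits, k=k)
            batch_logits = torch.full_like(batch_logits, float('-inf'))
            batch_logits[top_k_indices] = top_k_logits

        # Top-p (nucleus) filtering
        if params.top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(batch_logits, descending=True)
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            # Drop a token only if the mass ranked above it already exceeds top_p.
            # This keeps the token that crosses the threshold, and always keeps
            # the top token, so the distribution can never become empty.
            sorted_indices_to_remove = (cumulative_probs - sorted_probs) > params.top_p
            batch_logits[sorted_indices[sorted_indices_to_remove]] = float('-inf')

        probs = torch.softmax(batch_logits, dim=-1)
        next_tokens[i] = torch.multinomial(probs, num_samples=1)

    return next_tokens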
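
To make the top-p step concrete, here is a dependency-free sketch that works on a plain list of probabilities. It follows the common convention that the token which crosses the threshold is kept, so the kept set is never empty. The function names `top_p_keep` and `renormalize` are hypothetical.

```python
from typing import List


def top_p_keep(probs: List[float], top_p: float) -> List[int]:
    """Indices kept by nucleus filtering, most likely first."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    mass_before = 0.0  # probability of all tokens ranked above the current one
    for i in order:
        if mass_before > top_p:
            break  # the nucleus already covers top_p; drop the rest
        kept.append(i)
        mass_before += probs[i]
    return kept


def renormalize(probs: List[float], kept: List[int]) -> List[float]:
    """Zero out dropped tokens and rescale the rest to sum to 1."""
    total = sum(probs[i] for i in kept)
    kept_set = set(kept)
    return [p / total if i in kept_set else 0.0 for i, p in enumerate(probs)]


probs = [0.5, 0.3, 0.15, 0.05]
kept = top_p_keep(probs, top_p=0.7)
filtered = renormalize(probs, kept)

# A peaked distribution: the top token alone exceeds top_p but is still kept
peaked_kept = top_p_keep([0.95, 0.05], top_p=0.9)
```

With `top_p=0.7`, the first token alone covers only 0.5, so the second is added and the nucleus becomes the top two tokens. In the peaked case, a rule that drops every token whose cumulative probability exceeds `top_p` would discard the 0.95 token as well and leave nothing to sample from.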

Conclusion

We've covered the core architecture of SGLang:

  • System component overview and initialization
  • Request lifecycle from arrival to completion
  • Paged KV cache with RadixCache for prefix sharing
  • Token generation with various sampling strategies

In Part 2, we'll explore continuous batching, overlap scheduling, and tensor parallelism that make SGLang so fast.