Mini SGLang (Part 1) - Architecture, Engine & Request Flow

SGLang (Structured Generation Language) has emerged as one of the fastest LLM inference engines, achieving high throughput through innovations in KV cache memory management and scheduling. In this deep dive, we'll build an understanding of its architecture from the ground up, following a request from arrival to completion.

System Architecture Overview

SGLang is built around several key components:

┌─────────────────────────────────────────────────────────────┐
│                        API Server                           │
│  (FastAPI/HTTP endpoints, request validation, streaming)    │
└─────────────────────────┬───────────────────────────────────┘
                          │
┌─────────────────────────▼───────────────────────────────────┐
│                       Scheduler                             │
│  (Request queuing, priority, continuous batching)           │
└─────────────────────────┬───────────────────────────────────┘
                          │
┌─────────────────────────▼───────────────────────────────────┐
│                    Model Runner                             │
│  (Forward pass execution, KV cache management)              │
└─────────────────────────┬───────────────────────────────────┘
                          │
┌─────────────────────────▼───────────────────────────────────┐
│                  Memory Manager                             │
│  (Paged KV cache, RadixCache for prefix sharing)            │
└─────────────────────────────────────────────────────────────┘

Engine Initialization

When SGLang starts, several critical initialization steps occur:

Step 1: Model Loading

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

def load_model(model_path: str, config: ModelConfig) -> nn.Module:
    """Load model weights from disk into GPU memory."""
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto"
    )

    # Apply optimizations
    if config.use_flash_attention:
        model = patch_for_flash_attention(model)

    # Inference only: disable dropout and other training-time behavior
    return model.eval()

Step 2: KV Cache Initialization

The KV cache is the largest memory consumer after model weights:

class PagedKVCache:
    def __init__(self, num_layers, num_kv_heads, head_dim,
                 block_size, max_num_blocks, dtype, device):
        self.num_layers = num_layers
        self.block_size = block_size
        self.max_num_blocks = max_num_blocks

        # Allocate physical page pool
        # Shape: [num_layers, 2, max_num_blocks, block_size, num_kv_heads, head_dim]
        self.kv_cache = torch.zeros(
            num_layers, 2, max_num_blocks, block_size,
            num_kv_heads, head_dim,
            dtype=dtype, device=device
        )

        self.free_blocks = set(range(max_num_blocks))
        self.request_to_blocks = {}

For a 7B model with 32 layers, 32 KV heads, and 128 head dimension:

KV cache size = 2 (K and V) × 32 layers × batch × seq_len × 32 heads × 128 head_dim × 2 bytes (fp16)
             = 32 GB for batch=32, seq_len=2048
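That arithmetic generalizes; a small helper makes it easy to try other configurations:

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   batch_size: int, seq_len: int, bytes_per_param: int = 2) -> int:
    """KV cache footprint: keys and values, for every layer and every token."""
    num_tokens = batch_size * seq_len
    return 2 * num_layers * num_tokens * num_kv_heads * head_dim * bytes_per_param

size = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128,
                      batch_size=32, seq_len=2048)
print(size / 2**30)  # 32.0 (GiB), matching the figure above
```

Note how the footprint scales linearly with batch size and sequence length, which is exactly why the scheduler must gate admission on free cache blocks.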

Step 3: CUDA Kernel Compilation

SGLang uses custom CUDA kernels for critical operations:

def initialize_cuda_kernels(model_config, config):
    paged_attention_kernel = compile_paged_attention_kernel(
        head_size=model_config.hidden_size // model_config.num_attention_heads,
        num_kv_heads=model_config.num_kv_heads,
        block_size=config.block_size
    )

    sampling_kernel = compile_sampling_kernel(
        vocab_size=model_config.vocab_size,
        max_batch_size=config.max_batch_size
    )

    return {'paged_attention': paged_attention_kernel, 'sampling': sampling_kernel}

Request Lifecycle

Phase 1: Request Arrival

@app.post("/v1/generate")
async def generate(request: GenerateRequest):
    # Validate request (HTTPException yields a proper 400 instead of a 500)
    if not request.prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")

    # Tokenize prompt
    input_ids = tokenizer.encode(request.prompt, add_special_tokens=True)

    # Create internal request object
    internal_request = InternalRequest(
        request_id=generate_request_id(),
        input_ids=input_ids,
        sampling_params=SamplingParams(
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_tokens
        )
    )

    # Enqueue request
    await scheduler.add_request(internal_request)

    return await internal_request.wait_for_completion()

Phase 2: Request Admission

The scheduler admits requests when sufficient memory is available:

def _admit_new_requests(self):
    while self.waiting_queue:
        # Peek first; only pop once we know the request fits in memory
        request = self.waiting_queue[0]

        num_tokens = len(request.input_ids)
        required_blocks = math.ceil(num_tokens / self.block_size)

        if len(self.cache_manager.get_free_blocks()) < required_blocks:
            break  # Not enough memory

        request = self.waiting_queue.popleft()
        cache_info = self.cache_manager.allocate_for_request(
            request.request_id, request.input_ids
        )

        request.kv_cache_info = cache_info
        self.running_requests.append(request)

Phase 3: Prefill Phase

Process all input tokens to populate the KV cache:

def _execute_prefill(self, request):
    # Skip tokens whose KV entries were already recovered via prefix matching
    prefill_token_ids = request.input_ids[request.num_computed_tokens:]

    prefill_batch = ExecutionBatch(
        request_ids=[request.request_id],
        input_token_ids=torch.tensor([prefill_token_ids]),
        block_tables=[request.kv_cache_info['total_blocks']],
        is_prefill=True
    )

    logits = self.engine.forward(prefill_batch)
    next_token = self._sample_token(logits[:, -1, :], request.sampling_params)

    request.output_tokens.append(next_token.item())
    request.status = RequestStatus.GENERATING

Phase 4: Decode Phase

Generate tokens one at a time:

def _execute_decode(self, running_requests):
    input_token_ids = torch.tensor([
        [req.output_tokens[-1]] for req in running_requests
    ])

    decode_batch = ExecutionBatch(
        request_ids=[req.request_id for req in running_requests],
        input_token_ids=input_token_ids,
        block_tables=[req.kv_cache_info['total_blocks'] for req in running_requests],
        is_prefill=False
    )

    logits = self.engine.forward(decode_batch)
    next_tokens = self._sample_tokens(logits, running_requests)

    for request, next_token in zip(running_requests, next_tokens):
        request.output_tokens.append(next_token.item())
        if self._should_stop(request):
            request.status = RequestStatus.FINISHED

Memory Management

Paged KV Cache

Instead of allocating contiguous memory, SGLang uses fixed-size pages:

class CacheManager:
    def allocate_for_request(self, request_id: str, input_ids: List[int]):
        # Check for cached prefix using RadixCache
        matched_prefix, matched_blocks = self.radix_cache.match_prefix(input_ids)

        num_cached_tokens = len(matched_prefix)
        num_new_tokens = len(input_ids) - num_cached_tokens
        num_new_blocks = math.ceil(num_new_tokens / self.kv_cache.block_size)

        new_blocks = self._allocate_blocks(num_new_blocks)
        all_blocks = matched_blocks + new_blocks

        self.kv_cache.request_to_blocks[request_id] = all_blocks

        return {
            'num_cached_tokens': num_cached_tokens,
            'total_blocks': all_blocks
        }

RadixCache for Prefix Sharing

Multiple requests can share KV cache pages for common prefixes:

class RadixCache:
    def match_prefix(self, token_ids: List[int]):
        current_node = self.root
        matched_tokens = []
        matched_blocks = []

        for token_id in token_ids:
            if token_id in current_node.children:
                current_node = current_node.children[token_id]
                matched_tokens.append(token_id)
                if current_node.kv_block_id is not None:
                    matched_blocks.append(current_node.kv_block_id)
            else:
                break

        return matched_tokens, matched_blocks
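To see prefix sharing concretely, here is a self-contained toy version with an `insert` method added (its shape is an assumption; the real RadixCache compresses runs of tokens into edges rather than storing one node per token). Two requests that share a system-prompt prefix resolve to the same KV block without any recomputation:

```python
class RadixNode:
    def __init__(self):
        self.children = {}
        self.kv_block_id = None

class ToyRadixCache:
    def __init__(self):
        self.root = RadixNode()

    def insert(self, token_ids, block_ids):
        """Record which KV block holds each token's cache entry."""
        node = self.root
        for token_id, block_id in zip(token_ids, block_ids):
            node = node.children.setdefault(token_id, RadixNode())
            node.kv_block_id = block_id

    def match_prefix(self, token_ids):
        node = self.root
        matched_tokens, matched_blocks = [], []
        for token_id in token_ids:
            if token_id not in node.children:
                break
            node = node.children[token_id]
            matched_tokens.append(token_id)
            if node.kv_block_id is not None:
                matched_blocks.append(node.kv_block_id)
        return matched_tokens, matched_blocks

cache = ToyRadixCache()
# First request: system prompt [1, 2, 3] plus user tokens [10, 11]
cache.insert([1, 2, 3, 10, 11], block_ids=[0, 0, 0, 1, 1])

# Second request shares the system prompt, then diverges
tokens, blocks = cache.match_prefix([1, 2, 3, 20, 21])
print(tokens)  # [1, 2, 3]: the shared prefix matches, the rest does not
print(blocks)  # [0, 0, 0]: every matched token resolves to shared block 0
```

Only the tokens past the matched prefix need a prefill pass, which is where the large speedups on shared system prompts come from.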

Token Generation: Sampling

def sample_tokens(logits: torch.Tensor, sampling_params: List[SamplingParams]):
    batch_size, vocab_size = logits.shape
    next_tokens = torch.zeros(batch_size, dtype=torch.long)

    for i in range(batch_size):
        params = sampling_params[i]
        batch_logits = logits[i].clone()  # clone so in-place filtering below doesn't mutate the input

        # Temperature scaling
        if params.temperature != 1.0:
            batch_logits = batch_logits / params.temperature

        # Top-k filtering
        if params.top_k > 0:
            top_k_logits, top_k_indices = torch.topk(batch_logits, k=params.top_k)
            batch_logits = torch.full_like(batch_logits, float('-inf'))
            batch_logits[top_k_indices] = top_k_logits

        # Top-p (nucleus) filtering
        if params.top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(batch_logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > params.top_p
            # Shift the mask right so the highest-probability token is always kept;
            # otherwise a single token whose probability exceeds top_p would be removed
            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
            sorted_indices_to_remove[0] = False
            batch_logits[sorted_indices[sorted_indices_to_remove]] = float('-inf')

        probs = torch.softmax(batch_logits, dim=-1)
        next_tokens[i] = torch.multinomial(probs, num_samples=1)

    return next_tokens
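Nucleus filtering is easiest to see on a concrete distribution. This standalone snippet applies a top-p mask to a single logits vector, shifting the mask by one position so the highest-probability token always survives (without the shift, a dominant token whose probability alone exceeds `top_p` would be filtered out, leaving nothing to sample):

```python
import torch

def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Mask out the tail of the distribution beyond cumulative probability top_p."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative > top_p
    remove[1:] = remove[:-1].clone()  # shift: always keep the top token
    remove[0] = False
    out = logits.clone()
    out[sorted_indices[remove]] = float('-inf')
    return out

logits = torch.tensor([3.0, 1.0, 0.5, -2.0])   # top token has prob ~0.82
filtered = top_p_filter(logits, top_p=0.8)
probs = torch.softmax(filtered, dim=-1)
print((probs > 0).sum().item())  # 1: only the dominant token survives at top_p=0.8
```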

Conclusion

We've covered the core architecture of SGLang:

  • System component overview and initialization
  • Request lifecycle from arrival to completion
  • Paged KV cache with RadixCache for prefix sharing
  • Token generation with various sampling strategies

In Part 2, we'll explore continuous batching, overlap scheduling, and tensor parallelism that make SGLang so fast.