Mini SGLang Part 1: Architecture, Engine & Request Flow#
SGLang (Structured Generation Language) has emerged as one of the fastest LLM inference engines, achieving remarkable throughput through innovations in memory management and scheduling. In this deep dive, we'll build understanding from the ground up.
System Architecture Overview#
SGLang is built around several key components:
┌─────────────────────────────────────────────────────────────┐
│ API Server │
│ (FastAPI/HTTP endpoints, request validation, streaming) │
└─────────────────────────┬───────────────────────────────────┘
│
┌─────────────────────────▼───────────────────────────────────┐
│ Scheduler │
│ (Request queuing, priority, continuous batching) │
└─────────────────────────┬───────────────────────────────────┘
│
┌─────────────────────────▼───────────────────────────────────┐
│ Model Runner │
│ (Forward pass execution, KV cache management) │
└─────────────────────────┬───────────────────────────────────┘
│
┌─────────────────────────▼───────────────────────────────────┐
│ Memory Manager │
│ (Paged KV cache, RadixCache for prefix sharing) │
└─────────────────────────────────────────────────────────────┘

Engine Initialization#
When SGLang starts, several critical initialization steps occur:
Step 1: Model Loading#
def load_model(model_path: str, config: ModelConfig) -> nn.Module:
    """Load model weights from disk into GPU memory.

    Returns the model in eval mode, optionally patched for flash
    attention when the config requests it.
    """
    loaded = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # Optional kernel-level optimization pass.
    if config.use_flash_attention:
        loaded = patch_for_flash_attention(loaded)
    return loaded.eval()

# Step 2: KV Cache Initialization
The KV cache is the largest memory consumer after model weights:
class PagedKVCache:
    """Paged key/value cache backed by a fixed pool of physical blocks.

    The pool is a single tensor of shape
    [num_layers, 2, max_num_blocks, block_size, num_kv_heads, head_dim],
    where axis 1 separates keys (0) from values (1).  Requests map to
    (possibly non-contiguous) block ids via ``request_to_blocks``.
    """

    def __init__(self, num_layers, num_kv_heads, head_dim,
                 block_size, max_num_blocks, dtype, device):
        self.num_layers = num_layers
        # Store the full cache geometry (the original dropped the head
        # dimensions) so downstream code can read it back from the cache.
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.block_size = block_size
        self.max_num_blocks = max_num_blocks
        # Allocate the physical page pool, zero-initialized on the target device.
        # Shape: [num_layers, 2, max_num_blocks, block_size, num_kv_heads, head_dim]
        self.kv_cache = torch.zeros(
            num_layers, 2, max_num_blocks, block_size,
            num_kv_heads, head_dim,
            dtype=dtype, device=device,
        )
        # Every block starts free; allocation hands ids out of this set.
        self.free_blocks = set(range(max_num_blocks))
        # request_id -> list of block ids owned by that request.
        self.request_to_blocks = {}

# For a 7B model with 32 layers, 32 KV heads, and 128 head dimension:
KV cache size = 2 × 32 × batch × seq_len × 32 × 128 × 2 bytes
              = 32 GB for batch=32, seq_len=2048

Step 3: CUDA Kernel Compilation#
SGLang uses custom CUDA kernels for critical operations:
def initialize_cuda_kernels(model_config, config):
    """Compile the custom CUDA kernels used on the hot path.

    Returns a dict mapping kernel name to the compiled kernel object.
    """
    head_size = model_config.hidden_size // model_config.num_attention_heads
    return {
        'paged_attention': compile_paged_attention_kernel(
            head_size=head_size,
            num_kv_heads=model_config.num_kv_heads,
            block_size=config.block_size,
        ),
        'sampling': compile_sampling_kernel(
            vocab_size=model_config.vocab_size,
            max_batch_size=config.max_batch_size,
        ),
    }

# Request Lifecycle
Phase 1: Request Arrival#
@app.post("/v1/generate")
async def generate(request: GenerateRequest):
    """HTTP entry point: validate, tokenize, enqueue, and await the result."""
    if not request.prompt:
        raise ValueError("Prompt cannot be empty")

    # Convert the raw prompt into model token ids.
    token_ids = tokenizer.encode(request.prompt, add_special_tokens=True)

    params = SamplingParams(
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
    )
    internal_request = InternalRequest(
        request_id=generate_request_id(),
        input_ids=token_ids,
        sampling_params=params,
    )

    # Hand off to the scheduler and block until generation finishes.
    await scheduler.add_request(internal_request)
    return await internal_request.wait_for_completion()

# Phase 2: Request Admission
The scheduler admits requests when sufficient memory is available:
def _admit_new_requests(self):
while self.waiting_queue:
request = self.waiting_queue[0]
num_tokens = len(request.input_ids)
required_blocks = math.ceil(num_tokens / self.block_size)
if len(self.cache_manager.get_free_blocks()) < required_blocks:
break # Not enough memory
request = self.waiting_queue.popleft()
cache_info = self.cache_manager.allocate_for_request(
request.request_id, request.input_ids
)
request.kv_cache_info = cache_info
self.running_requests.append(request)Phase 3: Prefill Phase#
Process all input tokens to populate the KV cache:
def _execute_prefill(self, request):
    """Run the prefill forward pass for a single request.

    Processes all not-yet-computed prompt tokens in one batch, samples
    the first output token, and flips the request into decode mode.
    """
    pending_tokens = request.input_ids[request.num_computed_tokens:]
    batch = ExecutionBatch(
        request_ids=[request.request_id],
        input_token_ids=torch.tensor([pending_tokens]),
        block_tables=[request.kv_cache_info['total_blocks']],
        is_prefill=True,
    )
    logits = self.engine.forward(batch)
    # Only the logits at the final prompt position seed generation.
    first_token = self._sample_token(logits[:, -1, :], request.sampling_params)
    request.output_tokens.append(first_token.item())
    request.status = RequestStatus.GENERATING

# Phase 4: Decode Phase
Generate tokens one at a time:
def _execute_decode(self, running_requests):
    """Run one decode step: every running request emits one new token."""
    # Each request contributes exactly its most recently generated token.
    last_tokens = [[req.output_tokens[-1]] for req in running_requests]
    batch = ExecutionBatch(
        request_ids=[req.request_id for req in running_requests],
        input_token_ids=torch.tensor(last_tokens),
        block_tables=[req.kv_cache_info['total_blocks'] for req in running_requests],
        is_prefill=False,
    )
    logits = self.engine.forward(batch)
    sampled = self._sample_tokens(logits, running_requests)
    for req, token in zip(running_requests, sampled):
        req.output_tokens.append(token.item())
        # Mark EOS / length-limit hits; the scheduler reaps finished requests.
        if self._should_stop(req):
            req.status = RequestStatus.FINISHED

# Memory Management
Paged KV Cache#
Instead of allocating contiguous memory, SGLang uses fixed-size pages:
class CacheManager:
    """Allocates KV-cache blocks, reusing pages cached for shared prefixes."""

    def allocate_for_request(self, request_id: str, input_ids: List[int]):
        """Reserve blocks for a request, reusing any radix-cached prefix.

        Returns a dict with the number of prompt tokens already cached
        and the full (reused + fresh) block table for the request.
        """
        # Longest prompt prefix already resident in the cache.
        matched_prefix, matched_blocks = self.radix_cache.match_prefix(input_ids)
        num_cached_tokens = len(matched_prefix)
        # Fresh blocks are only needed for the uncached tail of the prompt.
        tail_tokens = len(input_ids) - num_cached_tokens
        fresh_blocks = self._allocate_blocks(
            math.ceil(tail_tokens / self.kv_cache.block_size)
        )
        block_table = matched_blocks + fresh_blocks
        self.kv_cache.request_to_blocks[request_id] = block_table
        return {
            'num_cached_tokens': num_cached_tokens,
            'total_blocks': block_table,
        }

# RadixCache for Prefix Sharing
Multiple requests can share KV cache pages for common prefixes:
class RadixCache:
    """Token-level radix tree mapping cached prefixes to KV block ids."""

    def match_prefix(self, token_ids: List[int]):
        """Walk the tree and return (matched_tokens, matched_block_ids).

        Matching stops at the first token with no child edge; nodes
        without a materialized KV block contribute tokens but no blocks.
        """
        node = self.root
        prefix = []
        blocks = []
        for tok in token_ids:
            child = node.children.get(tok)
            if child is None:
                break
            node = child
            prefix.append(tok)
            if node.kv_block_id is not None:
                blocks.append(node.kv_block_id)
        return prefix, blocks

# Token Generation: Sampling
def sample_tokens(logits: "torch.Tensor", sampling_params: "List[SamplingParams]"):
    """Sample one next token per batch row.

    Applies, per request: temperature scaling, top-k filtering, then
    top-p (nucleus) filtering, and finally draws from the surviving
    distribution.  ``temperature <= 0`` is treated as greedy argmax.

    Args:
        logits: float tensor of shape [batch_size, vocab_size].
        sampling_params: one params object per row, with attributes
            ``temperature``, ``top_k`` and ``top_p``.

    Returns:
        Long tensor of shape [batch_size] with the sampled token ids.
    """
    batch_size = logits.shape[0]
    next_tokens = torch.zeros(batch_size, dtype=torch.long)
    for i in range(batch_size):
        params = sampling_params[i]
        row_logits = logits[i]

        # Greedy path: temperature 0 would otherwise divide by zero.
        if params.temperature <= 0:
            next_tokens[i] = torch.argmax(row_logits)
            continue
        # Temperature scaling.
        if params.temperature != 1.0:
            row_logits = row_logits / params.temperature

        # Top-k filtering: keep only the k largest logits.
        if params.top_k > 0:
            top_k_logits, top_k_indices = torch.topk(row_logits, k=params.top_k)
            row_logits = torch.full_like(row_logits, float('-inf'))
            row_logits[top_k_indices] = top_k_logits

        # Top-p (nucleus) filtering: drop the low-probability tail.
        if params.top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(row_logits, descending=True)
            cumulative_probs = torch.cumsum(
                torch.softmax(sorted_logits, dim=-1), dim=-1
            )
            sorted_indices_to_remove = cumulative_probs > params.top_p
            # Shift the mask right so the token that crosses the top_p
            # threshold is kept.  Without this, a single dominant token
            # (prob > top_p on its own) would be removed too, leaving an
            # all -inf row -> NaN probabilities -> multinomial failure.
            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
            sorted_indices_to_remove[0] = False
            row_logits[sorted_indices[sorted_indices_to_remove]] = float('-inf')

        probs = torch.softmax(row_logits, dim=-1)
        next_tokens[i] = torch.multinomial(probs, num_samples=1)
    return next_tokens

# Conclusion
We've covered the core architecture of SGLang:
- System component overview and initialization
- Request lifecycle from arrival to completion
- Paged KV cache with RadixCache for prefix sharing
- Token generation with various sampling strategies
In Part 2, we'll explore continuous batching, overlap scheduling, and tensor parallelism that make SGLang so fast.