Mini SGLang Part 1: Architecture, Engine & Request Flow
SGLang (Structured Generation Language) has emerged as one of the fastest LLM inference engines, achieving remarkable throughput through innovations in memory management and scheduling. In this deep dive, we'll build understanding from the ground up.
System Architecture Overview
SGLang is built around several key components:
┌─────────────────────────────────────────────────────────────┐
│ API Server │
│ (FastAPI/HTTP endpoints, request validation, streaming) │
└─────────────────────────┬───────────────────────────────────┘
│
┌─────────────────────────▼───────────────────────────────────┐
│ Scheduler │
│ (Request queuing, priority, continuous batching) │
└─────────────────────────┬───────────────────────────────────┘
│
┌─────────────────────────▼───────────────────────────────────┐
│ Model Runner │
│ (Forward pass execution, KV cache management) │
└─────────────────────────┬───────────────────────────────────┘
│
┌─────────────────────────▼───────────────────────────────────┐
│ Memory Manager │
│ (Paged KV cache, RadixCache for prefix sharing) │
└─────────────────────────────────────────────────────────────┘
Engine Initialization
When SGLang starts, several critical initialization steps occur:
Step 1: Model Loading
def load_model(model_path: str, config: ModelConfig) -> nn.Module:
    """Load model weights from disk into GPU memory.

    Args:
        model_path: Filesystem path or hub identifier of the checkpoint.
        config: Model configuration; only ``use_flash_attention`` is read here.

    Returns:
        The loaded model, switched to eval mode.
    """
    loaded = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # Optionally swap in the FlashAttention-patched attention modules.
    if config.use_flash_attention:
        loaded = patch_for_flash_attention(loaded)
    return loaded.eval()
Step 2: KV Cache Initialization
The KV cache is the largest memory consumer after model weights:
class PagedKVCache:
    """Fixed-size paged pool of key/value tensors shared by all requests."""

    def __init__(self, num_layers, num_kv_heads, head_dim,
                 block_size, max_num_blocks, dtype, device):
        self.num_layers = num_layers
        self.block_size = block_size
        self.max_num_blocks = max_num_blocks
        # One physical page pool covering every layer; dim 1 separates K from V.
        # Layout: [num_layers, 2, max_num_blocks, block_size, num_kv_heads, head_dim]
        pool_shape = (num_layers, 2, max_num_blocks, block_size,
                      num_kv_heads, head_dim)
        self.kv_cache = torch.zeros(pool_shape, dtype=dtype, device=device)
        # Every physical block starts out unused.
        self.free_blocks = set(range(max_num_blocks))
        # Maps request_id -> list of physical block indices that request owns.
        self.request_to_blocks = {}
For a 7B model with 32 layers, 32 KV heads, and 128 head dimension:
KV cache size = 2 (K and V) × 32 layers × 32 heads × 128 head_dim × 2 bytes (fp16) × batch × seq_len
              = 32 GB for batch=32, seq_len=2048
Step 3: CUDA Kernel Compilation
SGLang uses custom CUDA kernels for critical operations:
def initialize_cuda_kernels(model_config, config):
    """Compile the custom CUDA kernels used on the serving hot path.

    Returns:
        Dict mapping kernel name ('paged_attention', 'sampling') to the
        compiled kernel handle.
    """
    # Per-head size is derived from the hidden size, not stored directly.
    head_size = model_config.hidden_size // model_config.num_attention_heads
    kernels = {
        'paged_attention': compile_paged_attention_kernel(
            head_size=head_size,
            num_kv_heads=model_config.num_kv_heads,
            block_size=config.block_size,
        ),
        'sampling': compile_sampling_kernel(
            vocab_size=model_config.vocab_size,
            max_batch_size=config.max_batch_size,
        ),
    }
    return kernels
Request Lifecycle
Phase 1: Request Arrival
@app.post("/v1/generate")
async def generate(request: GenerateRequest):
    """HTTP entry point: validate, tokenize, enqueue, and await completion."""
    # Reject empty prompts before doing any work.
    if not request.prompt:
        raise ValueError("Prompt cannot be empty")
    # Convert the prompt text into model token ids.
    token_ids = tokenizer.encode(request.prompt, add_special_tokens=True)
    # Wrap user-facing fields into the scheduler's internal representation.
    params = SamplingParams(
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
    )
    internal = InternalRequest(
        request_id=generate_request_id(),
        input_ids=token_ids,
        sampling_params=params,
    )
    # Hand off to the scheduler and block until generation finishes.
    await scheduler.add_request(internal)
    return await internal.wait_for_completion()
Phase 2: Request Admission
The scheduler admits requests when sufficient memory is available:
def _admit_new_requests(self):
while self.waiting_queue:
request = self.waiting_queue[0]
num_tokens = len(request.input_ids)
required_blocks = math.ceil(num_tokens / self.block_size)
if len(self.cache_manager.get_free_blocks()) < required_blocks:
break # Not enough memory
request = self.waiting_queue.popleft()
cache_info = self.cache_manager.allocate_for_request(
request.request_id, request.input_ids
)
request.kv_cache_info = cache_info
self.running_requests.append(request)
Phase 3: Prefill Phase
Process all input tokens to populate the KV cache:
def _execute_prefill(self, request):
    """Run one prefill forward pass for *request* and emit its first token."""
    # Only tokens not already covered by a cached prefix need computing.
    pending = request.input_ids[request.num_computed_tokens:]
    batch = ExecutionBatch(
        request_ids=[request.request_id],
        input_token_ids=torch.tensor([pending]),
        block_tables=[request.kv_cache_info['total_blocks']],
        is_prefill=True,
    )
    logits = self.engine.forward(batch)
    # Sample from the logits of the final position only; earlier positions
    # exist solely to populate the KV cache.
    first_token = self._sample_token(logits[:, -1, :], request.sampling_params)
    request.output_tokens.append(first_token.item())
    # The request now enters the token-by-token decode loop.
    request.status = RequestStatus.GENERATING
Phase 4: Decode Phase
Generate tokens one at a time:
def _execute_decode(self, running_requests):
    """Advance every running request by exactly one generated token."""
    # Each request contributes only its most recent token as decode input;
    # all earlier context lives in the KV cache.
    last_tokens = [[req.output_tokens[-1]] for req in running_requests]
    batch = ExecutionBatch(
        request_ids=[req.request_id for req in running_requests],
        input_token_ids=torch.tensor(last_tokens),
        block_tables=[req.kv_cache_info['total_blocks']
                      for req in running_requests],
        is_prefill=False,
    )
    logits = self.engine.forward(batch)
    sampled = self._sample_tokens(logits, running_requests)
    for req, token in zip(running_requests, sampled):
        req.output_tokens.append(token.item())
        # Retire requests that hit a stop condition (EOS, max_tokens, ...).
        if self._should_stop(req):
            req.status = RequestStatus.FINISHED
Memory Management
Paged KV Cache
Instead of allocating contiguous memory, SGLang uses fixed-size pages:
class CacheManager:
    """Allocates paged KV cache blocks, reusing shared prefixes when possible."""

    def allocate_for_request(self, request_id: str, input_ids: List[int]):
        """Reserve KV blocks for *input_ids*, reusing any cached prefix.

        Returns:
            Dict with 'num_cached_tokens' (prefix tokens already resident)
            and 'total_blocks' (shared prefix blocks + newly allocated ones).
        """
        # Longest prefix of input_ids already resident in the radix tree.
        matched_prefix, matched_blocks = self.radix_cache.match_prefix(input_ids)
        cached = len(matched_prefix)
        remaining = len(input_ids) - cached
        # Only the uncached suffix needs fresh pages.
        fresh_blocks = self._allocate_blocks(
            math.ceil(remaining / self.kv_cache.block_size)
        )
        block_table = matched_blocks + fresh_blocks
        self.kv_cache.request_to_blocks[request_id] = block_table
        return {
            'num_cached_tokens': cached,
            'total_blocks': block_table,
        }
RadixCache for Prefix Sharing
Multiple requests can share KV cache pages for common prefixes:
class RadixCache:
    """Token-level trie mapping shared prompt prefixes to KV cache blocks."""

    def match_prefix(self, token_ids: List[int]):
        """Walk the trie along *token_ids* and collect the matched prefix.

        Returns:
            (matched_tokens, matched_blocks): the longest prefix of
            token_ids present in the trie, and the KV block ids recorded
            on nodes along that path (nodes without a block are skipped).
        """
        node = self.root
        prefix = []
        blocks = []
        for tok in token_ids:
            child = node.children.get(tok)
            if child is None:
                # First divergence from any cached path ends the match.
                break
            node = child
            prefix.append(tok)
            if node.kv_block_id is not None:
                blocks.append(node.kv_block_id)
        return prefix, blocks
Token Generation: Sampling
def sample_tokens(logits: torch.Tensor, sampling_params: List[SamplingParams]):
    """Sample one next token per batch row under per-request sampling params.

    Args:
        logits: Float tensor of shape [batch_size, vocab_size].
        sampling_params: One SamplingParams per row; ``temperature``,
            ``top_k`` (0 disables), and ``top_p`` (1.0 disables) are read.

    Returns:
        Long tensor of shape [batch_size] with the sampled token ids.
    """
    batch_size, vocab_size = logits.shape
    next_tokens = torch.zeros(batch_size, dtype=torch.long)
    for i in range(batch_size):
        params = sampling_params[i]
        batch_logits = logits[i]
        # Temperature scaling: <1 sharpens, >1 flattens the distribution.
        if params.temperature != 1.0:
            batch_logits = batch_logits / params.temperature
        # Top-k filtering: keep only the k highest logits.
        if params.top_k > 0:
            top_k_logits, top_k_indices = torch.topk(batch_logits, k=params.top_k)
            batch_logits = torch.full_like(batch_logits, float('-inf'))
            batch_logits[top_k_indices] = top_k_logits
        # Top-p (nucleus) filtering: keep the smallest set of tokens whose
        # cumulative probability reaches top_p.
        if params.top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(batch_logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > params.top_p
            # BUG FIX: shift the removal mask right by one so the token that
            # crosses the top_p boundary is kept. Without the shift, a single
            # dominant token (probability > top_p) gets removed along with
            # everything else, leaving all logits at -inf and producing NaNs
            # in the final softmax.
            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
            sorted_indices_to_remove[0] = False
            batch_logits[sorted_indices[sorted_indices_to_remove]] = float('-inf')
        probs = torch.softmax(batch_logits, dim=-1)
        next_tokens[i] = torch.multinomial(probs, num_samples=1)
    return next_tokens
Conclusion
We've covered the core architecture of SGLang:
- System component overview and initialization
- Request lifecycle from arrival to completion
- Paged KV cache with RadixCache for prefix sharing
- Token generation with various sampling strategies
In Part 2, we'll explore continuous batching, overlap scheduling, and tensor parallelism that make SGLang so fast.