Extract Kernel Definitions
Extract kernel schemas and definitions from SGLang model implementations, with deduplication, and add them to ./flashinfer_trace/ with vanilla Python reference implementations.
Description
This skill analyzes SGLang model implementations to extract the complete set of GPU kernels used during inference. It identifies kernel types (MLA, MOE, GQA, RMSNorm, GEMM), extracts their parameters, generates Definition JSON schemas with Python reference implementations, and handles deduplication across multiple models.
Usage
# Extract kernels from DeepSeek V3 model (MLA + MoE) /extract-kernel-definitions --model-name deepseek_v3 # Extract from Llama (GQA) /extract-kernel-definitions --model-name llama # Extract from multiple models with automatic deduplication /extract-kernel-definitions --model-name deepseek_v3 /extract-kernel-definitions --model-name llama /extract-kernel-definitions --model-name qwen2_moe # Extract only prefill or decode kernels /extract-kernel-definitions --model-name llama --execution-modes prefill /extract-kernel-definitions --model-name llama --execution-modes decode
Parameters
- •
model_name(required): Model name to extract kernels from (e.g., "deepseek_v3", "llama", "qwen2_moe") - •
execution_modes(optional): Inference execution modes to analyze (default: ["prefill", "decode"]) - •
include_quantized(optional): Include quantized kernel variants (default: true) - •
deduplicate(optional): Check for and skip existing definitions (default: true)
Prerequisites
Run /clone-repos first to set up the tmp/ directory with SGLang and FlashInfer (the flashinfer_trace/ directory is already part of this repository).
What This Skill Does
Phase 1: Model Analysis
- •
Locate Model Implementation:
- •Search
tmp/sglang/python/sglang/srt/models/{model_name}.py - •Identify model class (e.g.,
DeepseekV3ForCausalLM,LlamaForCausalLM) - •Parse model architecture from config
- •Search
- •
Identify Layer Components:
- •Attention mechanism (GQA, MHA, MLA)
- •MLP/FFN structure (Dense, MoE)
- •Normalization layers (RMSNorm, LayerNorm)
- •Embedding and output projections
- •
Extract Execution Paths:
- •Prefill path (batch processing, variable sequence length)
- •Decode path (single token, paged KV cache)
Phase 2: Kernel Extraction
For each layer component, extract:
Attention Kernels
- •GQA:
gqa_paged_decode,gqa_paged_prefill,gqa_ragged_prefill - •MLA:
mla_paged_decode,mla_paged_prefill - •Parameters: num_heads, num_kv_heads, head_dim, page_size, ckv_dim, kpe_dim
MoE Kernels
- •Expert routing and selection
- •Expert execution (FP8 quantized, block-scaled)
- •Parameters: num_experts, topk, hidden_size, intermediate_size, group_size
Normalization Kernels
- •
rmsnorm_h{hidden_size} - •
fused_add_rmsnorm_h{hidden_size} - •Parameters: hidden_size, epsilon
GEMM Kernels
- •QKV projection, O projection
- •Gate/Up projection, Down projection
- •Parameters: M (variable), N (output dim), K (input dim)
Phase 3: Deduplication
- •
Load Existing Definitions:
- •Scan
flashinfer_trace/definitions/for existing JSONs - •Build index of definition names and signatures
- •Scan
- •
Compare Extracted Kernels:
- •Check if kernel with same name exists
- •Verify parameter compatibility
- •Skip existing definitions, report what was skipped
- •
Handle Shared Kernels:
- •Kernels like
rmsnorm_h4096may be used by multiple models - •Only create once, add model tags to existing definitions
- •Kernels like
Phase 4: Definition Generation
For each new kernel, generate a Definition JSON following the standards below.
Definition JSON Standards
Naming Convention
Follow the pattern: {op_type}_{variant}_{key_params}
Parameter Abbreviations:
- •
h= num_heads or hidden_size (context-dependent) - •
kv= num_kv_heads - •
d= head_dim - •
ps= page_size - •
ckv= compressed_kv_dim - •
kpe= key_positional_encoding_dim - •
e= num_experts - •
i= intermediate_size - •
topk= top_k_experts - •
ng= n_group - •
kg= topk_group - •
v= vocab_size
Examples by op_type:
- •Attention:
gqa_paged_decode_h32_kv8_d128_ps1,mla_paged_prefill_h16_ckv512_kpe64_ps1 - •MoE:
moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048 - •Normalization:
rmsnorm_h4096,fused_add_rmsnorm_h7168 - •Sampling:
top_k_sampling_from_probs_v129280
Tag Patterns
Tags follow the pattern {category}:{value}:
| Category | Values | Description |
|---|---|---|
status | verified, unverified | Whether reference implementation is validated |
stage | decode, prefill | Inference execution mode |
model | deepseek-v3, deepseek-r1, llama-3.1-8b, etc. | Associated model(s) |
quantization | float8_e4m3fn, nvfp4, int8, int4 | Quantization format |
routing | pre-computed, on-the-fly | For MoE routing type |
Example tags array:
"tags": [ "stage:decode", "status:verified", "model:deepseek-v3", "model:deepseek-r1", "quantization:float8_e4m3fn" ]
Axes Structure
"axes": {
"batch_size": {
"type": "var",
"description": "Batch size (number of sequences)"
},
"num_qo_heads": {
"type": "const",
"value": 16,
"description": "Number of query heads after tensor parallel split (128/8=16)."
}
}
Rules:
- •Variable axes (
type: "var"): runtime dimensions like batch_size, seq_len, num_pages - •Constant axes (
type: "const"): model-specific values withvaluefield - •Always include
descriptionfor complex or model-specific axes
Constraints Field
Add constraints for input validation:
"constraints": [ "len_indptr == batch_size + 1", "num_kv_indices == kv_indptr[-1].item()" ]
Inputs/Outputs Structure
"inputs": {
"tensor_name": {
"shape": ["axis1", "axis2", "axis3"],
"dtype": "bfloat16",
"description": "Description of the tensor"
},
"scalar_input": {
"shape": null,
"dtype": "float32",
"description": "Scalar parameter"
}
}
dtype values: float32, float16, bfloat16, float8_e4m3fn, float8_e5m2, int32, int64
Reference Implementation
- •
Generate Name: Follow naming convention
{op_type}_{variant}_{params}- •Example:
mla_paged_decode_h16_ckv512_kpe64_ps1 - •Example:
moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048
- •Example:
- •
Define Axes:
- •Constant axes: fixed at compile time (e.g., hidden_size, num_heads)
- •Variable axes: determined at runtime (e.g., batch_size, seq_len)
- •
Specify Inputs/Outputs:
- •Tensor shapes using axis names
- •Data types (float16, bfloat16, float8_e4m3fn, etc.)
- •
Write Reference Implementation:
- •ALWAYS prioritize FlashInfer tests (
tmp/flashinfer/tests/) for reference implementations - •Look for vanilla PyTorch reference functions (e.g.,
ref_attention(),ref_rmsnorm()) - •Only use SGLang vanilla implementation when FlashInfer doesn't have the kernel
- •Plain PyTorch implementation with step-by-step computation
- •See "Reference Implementation Sources" section below
- •ALWAYS prioritize FlashInfer tests (
Phase 5: Save Definitions
Output to flashinfer_trace/definitions/{op_type}/{definition_name}.json
Output Format
Definition JSON Example (MLA Decode)
{
"name": "mla_paged_decode_h16_ckv512_kpe64_ps1",
"description": "Batched Multi-head Latent Attention decode with a paged KV cache. Captured from DeepSeek-V3 with tensor parallel size 8.",
"op_type": "mla_paged",
"tags": [
"stage:decode",
"status:verified",
"model:deepseek-v3",
"model:deepseek-r1"
],
"axes": {
"batch_size": { "type": "var" },
"num_qo_heads": {
"type": "const",
"value": 16,
"description": "Number of query heads after tensor parallel split (128/8=16)."
},
"head_dim_ckv": { "type": "const", "value": 512 },
"head_dim_kpe": { "type": "const", "value": 64 },
"page_size": { "type": "const", "value": 1 },
"num_pages": { "type": "var", "description": "Total number of allocated pages in the KV cache." },
"len_indptr": { "type": "var", "description": "Length of kv_indptr array." },
"num_kv_indices": { "type": "var", "description": "Total number of KV page indices." }
},
"constraints": [
"len_indptr == batch_size + 1",
"num_kv_indices == kv_indptr[-1].item()"
],
"inputs": {
"q_nope": {
"shape": ["batch_size", "num_qo_heads", "head_dim_ckv"],
"dtype": "bfloat16",
"description": "Query tensor without positional encoding component."
},
"q_pe": {
"shape": ["batch_size", "num_qo_heads", "head_dim_kpe"],
"dtype": "bfloat16",
"description": "Query positional encoding component."
},
"ckv_cache": {
"shape": ["num_pages", "page_size", "head_dim_ckv"],
"dtype": "bfloat16",
"description": "Compressed key-value cache."
},
"kpe_cache": {
"shape": ["num_pages", "page_size", "head_dim_kpe"],
"dtype": "bfloat16",
"description": "Key positional encoding cache."
},
"kv_indptr": {
"shape": ["len_indptr"],
"dtype": "int32",
"description": "KV page offsets for each sequence."
},
"kv_indices": {
"shape": ["num_kv_indices"],
"dtype": "int32",
"description": "Page indices for KV cache lookups."
},
"sm_scale": {
"shape": null,
"dtype": "float32",
"description": "Softmax scale. Default is (1/sqrt(128 + 64) = 1/sqrt(192))."
}
},
"outputs": {
"output": { "shape": ["batch_size", "num_qo_heads", "head_dim_ckv"], "dtype": "bfloat16" },
"lse": {
"shape": ["batch_size", "num_qo_heads"],
"dtype": "float32",
"description": "The 2-based log-sum-exp of attention logits."
}
},
"reference": "import math\nimport torch\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, kv_indptr, kv_indices, sm_scale):\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n ..."
}
Definition JSON Example (MoE)
{
"name": "moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048",
"description": "FP8 block scale MoE operation. Routing and two grouped-GEMM included.",
"op_type": "moe",
"tags": [
"status:verified",
"model:deepseek-v3",
"model:deepseek-r1",
"quantization:float8_e4m3fn"
],
"axes": {
"seq_len": { "type": "var", "description": "Sequence length (number of tokens)" },
"num_experts": {
"type": "const",
"value": 256,
"description": "Total number of experts."
},
"num_local_experts": {
"type": "const",
"value": 32,
"description": "Number of local experts with EP size 8."
},
"hidden_size": {
"type": "const",
"value": 7168,
"description": "Hidden dimension size."
},
"intermediate_size": {
"type": "const",
"value": 2048,
"description": "MoE intermediate layer size."
}
},
"inputs": {
"routing_logits": {
"shape": ["seq_len", "num_experts"],
"dtype": "float32",
"description": "Tensor of routing logits for expert selection"
},
"hidden_states": {
"shape": ["seq_len", "hidden_size"],
"dtype": "float8_e4m3fn",
"description": "Input hidden states tensor (FP8 quantized)"
},
"local_expert_offset": {
"shape": null,
"dtype": "int32",
"description": "Offset of local experts in global expert space."
},
"routed_scaling_factor": {
"shape": null,
"dtype": "float32",
"description": "Scaling factor for routing weights."
}
},
"outputs": {
"output": {
"shape": ["seq_len", "hidden_size"],
"dtype": "bfloat16",
"description": "Final MoE output tensor"
}
},
"reference": "import torch\n\n@torch.no_grad()\ndef run(routing_logits, hidden_states, ...):\n # Check constants\n assert H == 7168, 'hidden_size must be 7168'\n ..."
}
Implementation Steps
When executing this skill:
- •
Locate model file:
bashls tmp/sglang/python/sglang/srt/models/ | grep -i {model_name} - •
Read model implementation:
- •Parse the Python file
- •Identify forward() methods
- •Extract kernel calls (attention, MoE, norm, GEMM)
- •
Extract kernel parameters from model config:
- •Look for config class (e.g.,
DeepseekV3Config) - •Extract: num_hidden_layers, hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, etc.
- •Look for config class (e.g.,
- •
Check existing definitions:
bashls flashinfer_trace/definitions/*/
- •
Generate new definition JSONs:
- •Create directory if needed:
mkdir -p flashinfer_trace/definitions/{op_type}/ - •Write JSON file with reference implementation
- •Create directory if needed:
- •
Report results:
- •List new definitions created
- •List existing definitions skipped (deduplication)
- •List kernels shared across models
SGLang Code Patterns to Look For
Attention Kernel Calls
# GQA decode (paged) flashinfer.batch_decode_with_paged_kv_cache(...) # MLA decode from sglang.srt.layers.attention.mla_decode import mla_decode_attention
MoE Kernel Calls
# DeepSeek MoE from sglang.srt.layers.moe.fused_moe import fused_moe fused_moe(hidden_states, w1, w2, w3, topk_weights, topk_ids, ...)
Normalization
# RMSNorm from sglang.srt.layers.layernorm import RMSNorm self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
Ground Truth Hierarchy
When extracting kernel definitions, use different sources for different purposes:
For Model Constants: SGLang Model Config (Required)
- •Location:
tmp/sglang/python/sglang/srt/models/{model_name}.py - •Use for: Extracting model-specific constant values
- •Examples:
- •
num_attention_heads,num_key_value_heads,head_dim - •
hidden_size,intermediate_size - •
num_experts,topk,n_group,topk_group - •
page_size(from SGLang's paged attention configuration)
- •
- •Important: Always align constants with SGLang's model config to ensure compatibility
For Reference run() Implementation: FlashInfer Unit Tests (PRIMARY - ALWAYS CHECK FIRST)
- •When: FlashInfer has a kernel implementation and corresponding unit test (CHECK THIS FIRST)
- •Location:
tmp/flashinfer/tests/ - •Why: FlashInfer tests contain vanilla PyTorch implementations that serve as ground truth
- •Examples:
code
tests/test_batch_decode.py # GQA decode reference tests/test_batch_prefill.py # GQA prefill reference tests/test_norm.py # RMSNorm reference tests/test_mla.py # MLA reference
- •Pattern: Look for functions like
ref_attention(),ref_rmsnorm(),ref_mla(), etc. - •Use for: Writing the
referencefield in Definition JSON - •IMPORTANT: Always search FlashInfer tests directory FIRST before falling back to SGLang
For Reference run() Implementation: SGLang Vanilla (FALLBACK ONLY)
- •When: FlashInfer does NOT have the kernel (e.g., some MoE variants, custom ops)
- •Location:
tmp/sglang/python/sglang/srt/layers/ - •Examples:
code
layers/moe/fused_moe.py # MoE vanilla forward (when FlashInfer MoE not available) layers/attention/triton_ops/ # Custom attention implementations
- •Important: Only use SGLang vanilla implementation when FlashInfer doesn't have the kernel
- •How to verify: Check
tmp/flashinfer/tests/first to confirm kernel is not available
For API Signature: FlashInfer Python API
- •When: Determining input/output tensor specifications
- •Location:
tmp/flashinfer/python/flashinfer/ - •Examples:
- •
flashinfer.attention.batch_decode_with_paged_kv_cache - •
flashinfer.norm.rmsnorm - •
flashinfer.mla.BatchMLAPagedAttentionWrapper
- •
- •Use for: Input/output shape and dtype specifications in Definition JSON
Reference Implementation Sources
The reference field in Definition JSON contains a run() function. Always prioritize FlashInfer unit tests over SGLang implementations.
Primary Source: FlashInfer Unit Tests (REQUIRED - CHECK THIS FIRST)
- •
Location:
tmp/flashinfer/tests/ - •
Why: FlashInfer tests contain ground-truth vanilla PyTorch implementations
- •
How to find: Search for reference functions in test files
bash# Search for reference implementations grep -r "def ref_" tmp/flashinfer/tests/ grep -r "def reference" tmp/flashinfer/tests/
- •
Kernel type to test file mapping:
Kernel Type Test File Reference Function GQA decode test_batch_decode.pyref_attention()or inline referenceGQA prefill test_batch_prefill.pyref_attention()or inline referenceMLA test_mla.pyref_mla()or inline referenceRMSNorm test_norm.pyref_rmsnorm()orref_fused_add_rmsnorm()Sampling test_sampling.pyReference sampling implementations - •
Important: The reference
run()should match FlashInfer's test implementation exactly, with constant values updated to match SGLang model config
Fallback Source: SGLang Vanilla Implementation (ONLY when FlashInfer unavailable)
- •When to use: ONLY when FlashInfer does NOT have a unit test for this kernel
- •Location:
tmp/sglang/python/sglang/srt/layers/ - •Common cases requiring SGLang fallback:
- •Custom MoE implementations:
layers/moe/fused_moe.py - •Model-specific attention variants not in FlashInfer
- •Quantized kernels not yet supported by FlashInfer
- •Custom MoE implementations:
- •How to check if FlashInfer has the kernel:
bash
# Check if FlashInfer has a test for this kernel type ls tmp/flashinfer/tests/test_*.py | xargs grep -l "your_kernel_name" # Check if FlashInfer has the API grep -r "def your_kernel" tmp/flashinfer/python/flashinfer/
Reference Implementation Guidelines
- •Pure PyTorch: Use only
torchoperations, no external kernels - •Step-by-step: Break down computation into clear steps
- •Match signatures: Input/output names must match definition schema
- •Include all outputs: Return all outputs specified in definition (e.g., both
outputandlsefor attention)
Example reference implementation pattern:
import torch
import math
def run(q, k, v, ...):
# Step 1: Compute attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
# Step 2: Apply softmax
attn_weights = torch.softmax(scores, dim=-1)
# Step 3: Compute output
output = torch.matmul(attn_weights, v)
# Step 4: Compute log-sum-exp (if needed)
lse = torch.logsumexp(scores, dim=-1)
return output, lse
Kernel Type to Model Mapping
| Model | Attention | MLP | Normalization |
|---|---|---|---|
| DeepSeek V3/R1 | MLA | MoE | RMSNorm |
| Llama 3.x | GQA | Dense | RMSNorm |
| Qwen2 MoE | GQA | MoE | RMSNorm |
| Mixtral | GQA | MoE | RMSNorm |
Error Handling
Model Not Found
- •Error: Model file doesn't exist in SGLang
- •Handling: List available models, suggest closest match
Unsupported Kernel Type
- •Error: Unknown kernel pattern in model
- •Handling: Log warning, skip kernel, report for manual review
Duplicate with Conflict
- •Error: Definition exists with different parameters
- •Handling: Create new versioned definition, flag for review
Integration with Other Skills
# Complete workflow /clone-repos # Extract from multiple models /extract-kernel-definitions --model-name deepseek_v3 /extract-kernel-definitions --model-name llama /extract-kernel-definitions --model-name qwen2_moe # Add tests for new definitions /add-reference-tests --op-type mla_paged /add-reference-tests --op-type moe
Notes
- •Each kernel may have multiple variants for different execution modes (prefill/decode)
- •Quantized kernels are treated as separate definitions
- •Reference implementations prioritize clarity over performance
- •Model tags enable tracing which models use each kernel
- •Deduplication is essential for maintaining a clean dataset
Maintaining This Document
Update this file when adding new op_types, changing Definition JSON schema, or modifying extraction patterns.