tlx-api-reference

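TLX DSL API reference for low-level GPU primitives. Use when writing or modifying TLX kernel code that uses barriers (mbarrier, named barriers), memory allocation (local_alloc, SMEM, TMEM), TMA operations, warp specialization (async_tasks, async_task), CLC (cluster launch control), or wgmma instructions. Covers the differences between Hopper and Blackwell hardware in full.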

SKILL.md
---
name: tlx-api-reference
description: >
  TLX DSL API reference for low-level GPU primitives. Use when writing or
  modifying TLX kernel code that uses barriers (mbarrier, named barriers),
  memory allocation (local_alloc, SMEM, TMEM), TMA operations, warp
  specialization (async_tasks, async_task), CLC (cluster launch control),
  or wgmma instructions. Covers Hopper and Blackwell hardware differences.
---

TLX API Quick Reference

Warp Specialization

| Function | Description | Arch |
|---|---|---|
| `tlx.async_tasks()` | Context manager wrapping all async task regions | Both |
| `tlx.async_task([task_ids])` | Assign code to specific task IDs (e.g., `[0]` = producer, `[1, 2]` = consumers) | Both |
| `tlx.async_task(num_warps=N, num_regs=R)` | Explicit warp/register allocation for a task | Both |
| `tlx.async_task("default", num_regs=R)` | Default task for code outside explicit tasks | Both |
| `tlx.async_task_replica_id()` | Returns replica ID inside an async region | Both |

Warp specialization skeleton

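```python
# Skeleton only: place inside a TLX kernel and replace each `pass` with real work.
with tlx.async_tasks():
    with tlx.async_task([0]):       # Producer
        pass                        # TMA loads go here
    with tlx.async_task([1, 2]):    # Consumers
        pass                        # MMA compute goes here
```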

Memory Barriers

mbarrier (shared-memory allocated)

| Function | Description | Arch |
|---|---|---|
| `tlx.alloc_barriers(num_barriers, arrive_count=1)` | Allocate SMEM barriers and initialize with arrive count | Both |
| `tlx.barrier_expect_bytes(bar, bytes, pred=None)` | Set expected transaction byte count on barrier | Both |
| `tlx.barrier_wait(bar, phase, pred=None)` | Wait until barrier phase flips (LOCAL mbarrier only) | Both |
| `tlx.barrier_arrive(bar, arrive_count=1, remote_cta_rank=None)` | Signal arrival at barrier | Both |
| `tlx.cluster_barrier()` | Full cluster-wide synchronization barrier | Both |

arrive_count rules:

  • Implicit arrive from barrier_expect_bytes: use arrive_count=1
  • barrier_arrive inside tlx.async_task: arrive_count = number of warp groups
  • barrier_arrive outside tlx.async_task: arrive_count=1 (only tid==0 arrives)
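
To see why `arrive_count` has to match the number of arrivers, the toy model below treats an mbarrier as a pending-arrival counter plus a phase bit. It is a host-side sketch only: `MBarrierModel` and its methods are hypothetical names, not part of TLX, and transaction-byte tracking is left out.

```python
class MBarrierModel:
    """Toy host-side model of an mbarrier: pending-arrival counter plus phase bit.

    Hypothetical illustration only; this is not a TLX class.
    """

    def __init__(self, arrive_count: int = 1):
        if arrive_count < 1:
            raise ValueError("arrive_count must be >= 1")
        self.arrive_count = arrive_count
        self.pending = arrive_count
        self.phase = 0

    def arrive(self, count: int = 1) -> None:
        """Record arrivals; flip the phase once the expected count is reached."""
        if count < 1 or count > self.pending:
            raise ValueError("arrival count does not match what the barrier expects")
        self.pending -= count
        if self.pending == 0:
            self.phase ^= 1
            self.pending = self.arrive_count  # re-arm for the next phase

    def wait_would_pass(self, phase: int) -> bool:
        """A waiter on `phase` is released once the barrier has left that phase."""
        return self.phase != phase


# Two consumer warp groups must both arrive before the "empty" barrier flips.
bar_empty = MBarrierModel(arrive_count=2)
bar_empty.arrive()
after_first = bar_empty.wait_would_pass(0)    # False: one arrival still missing
bar_empty.arrive()
after_second = bar_empty.wait_would_pass(0)   # True: the phase flipped
print(after_first, after_second)
```

If the barrier were allocated with `arrive_count=1` here, the first consumer's arrival alone would flip the phase and release the waiter too early.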

Named barriers (hardware-allocated, indices 0–15)

| Function | Description | Arch |
|---|---|---|
| `tlx.named_barrier_wait(bar_id, num_threads)` | Wait until `num_threads` arrive at `bar_id` | NVIDIA |
| `tlx.named_barrier_arrive(bar_id, num_threads)` | Signal arrival at `bar_id` | NVIDIA |

num_threads must be a multiple of 32 (warp size). Typically num_warp_groups * warps_per_group * 32.
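
The thread-count rule can be checked on the host before launching a kernel. The helpers below are a minimal sketch: `named_barrier_threads` and `check_named_barrier` are hypothetical names, and the default of 4 warps per warp group is an assumption taken from the PingPong example in this document.

```python
WARP_SIZE = 32
NUM_NAMED_BARRIERS = 16  # indices 0-15, as stated in the section heading


def named_barrier_threads(num_warp_groups: int, warps_per_group: int = 4) -> int:
    """Threads participating in a named barrier shared by whole warp groups."""
    if num_warp_groups < 1 or warps_per_group < 1:
        raise ValueError("warp group counts must be positive")
    return num_warp_groups * warps_per_group * WARP_SIZE


def check_named_barrier(bar_id: int, num_threads: int) -> None:
    """Raise ValueError if the arguments break the rules quoted above."""
    if not 0 <= bar_id < NUM_NAMED_BARRIERS:
        raise ValueError(f"bar_id {bar_id} is outside 0-15")
    if num_threads <= 0 or num_threads % WARP_SIZE != 0:
        raise ValueError(f"num_threads {num_threads} is not a positive multiple of 32")


threads = named_barrier_threads(num_warp_groups=2)  # 2 * 4 * 32
check_named_barrier(9, threads)
print(threads)
```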

Used for PingPong scheduling to prevent tensor core contention between consumer warp groups.

Memory Operations

SMEM / TMEM allocation

| Function | Description | Arch |
|---|---|---|
| `tlx.local_alloc(shape, dtype, num, storage=smem, reuse=None, layout=None)` | Allocate buffered tensor in SMEM or TMEM | Both (TMEM: Blackwell) |
| `tlx.storage_alias_spec(storage=smem, buffer_size_bytes=None)` | Define shared buffer region for multiple `local_alloc` calls via `reuse` | Both |
| `tlx.local_view(buf, index)` | Get view of a single buffer from a multi-buffered tensor | Both |
| `tlx.local_slice(buf, start, end)` | Slice a sub-range of a buffered tensor | Both |
| `tlx.subslice(tensor, dim, start, size)` | Subslice a tensor along a dimension | Both |
| `tlx.local_load(buf)` | Load from SMEM/TMEM buffer into registers | Both |
| `tlx.local_store(val, buf)` | Store from registers into SMEM/TMEM buffer | Both |
| `tlx.local_trans(buf)` | Transpose a shared memory buffer | Both |
| `tlx.local_reinterpret(buf, dtype)` | Reinterpret buffer with a different dtype | Both |
| `tlx.remote_view(buf, remote_cta_rank)` | Get view of buffer in a remote CTA's SMEM | Both |
| `tlx.remote_shmem_store(val, buf)` | Store to remote CTA's shared memory | Both |
| `tlx.async_remote_shmem_store(val, buf)` | Async store to remote CTA's shared memory | Both |
| `tlx.tmem_copy(src, dst)` | Copy between TMEM buffers | Blackwell |
| `tlx.fence_async_shared()` | Memory fence for async shared memory operations | Both |

Storage kinds: tlx.storage_kind.smem, tlx.storage_kind.tmem (Blackwell), tlx.storage_kind.smemCluster
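
Because `local_alloc` takes a `num` buffer count, the SMEM footprint grows with pipeline depth. The sketch below estimates that footprint on the host. Everything in it is illustrative: `buffer_bytes`, `smem_budget`, and the `DTYPE_BYTES` table are hypothetical helpers, the 227 KiB capacity is an assumed per-block limit that should be checked against the target GPU, and padding, alignment, and barrier storage are ignored.

```python
from math import prod

# Element sizes in bytes for a few common dtypes (illustrative table).
DTYPE_BYTES = {"float32": 4, "float16": 2, "bfloat16": 2, "float8": 1}


def buffer_bytes(shape, dtype, num):
    """Bytes used by `num` buffers of the given shape and dtype."""
    return prod(shape) * DTYPE_BYTES[dtype] * num


def smem_budget(allocs, capacity_bytes):
    """Return (total_bytes, fits) for a list of (shape, dtype, num) allocations."""
    total = sum(buffer_bytes(shape, dtype, num) for shape, dtype, num in allocs)
    return total, total <= capacity_bytes


# A 3-stage pipeline holding a 128x64 A tile and a 64x256 B tile in fp16.
allocs = [((128, 64), "float16", 3), ((64, 256), "float16", 3)]
total, fits = smem_budget(allocs, capacity_bytes=227 * 1024)  # assumed capacity
print(total, fits)
```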

TMA (Tensor Memory Accelerator)

| Function | Description | Arch |
|---|---|---|
| `tlx.make_tensor_descriptor(ptr, shape, strides, block_shape)` | Create TMA descriptor from pointer (host-side) | Hopper+ |
| `tlx.allocate_tensor_descriptor(ptr, shape, strides, block_shape, swizzle_mode)` | Allocate and fill TMA descriptor in SMEM | Hopper+ |
| `tlx.reinterpret_tensor_descriptor(desc, dtype)` | Reinterpret TMA descriptor with different dtype | Hopper+ |
| `tlx.async_descriptor_load(desc, indices, barrier=None)` | Async TMA load from global → SMEM, tracked by barrier | Hopper+ |
| `tlx.async_descriptor_store(desc, val, indices)` | Async TMA store from registers → global | Hopper+ |
| `tlx.async_descriptor_store_wait()` | Wait for all pending TMA stores to complete | Hopper+ |
| `tlx.async_load(ptr, buf, barrier)` | Async bulk copy global → SMEM (cp.async) | Hopper+ |
| `tlx.async_load_commit_group()` | Commit async load group | Hopper+ |
| `tlx.async_load_wait_group(n)` | Wait for async load groups (`n` pending allowed) | Hopper+ |
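
The "`n` pending allowed" wording of `async_load_wait_group` is easiest to read as a queue of committed groups. The model below is a host-side sketch of that reading, assuming groups complete in commit order as with `cp.async` groups; `AsyncLoadGroups` is a hypothetical class, not a TLX API.

```python
from collections import deque


class AsyncLoadGroups:
    """Toy model of commit-group bookkeeping (hypothetical, host-side only)."""

    def __init__(self):
        self._open = []          # loads issued since the last commit
        self._pending = deque()  # committed groups, oldest first

    def load(self, name):
        self._open.append(name)

    def commit_group(self):
        self._pending.append(tuple(self._open))
        self._open = []

    def wait_group(self, n):
        """Retire the oldest groups until at most `n` remain; return retired loads."""
        done = []
        while len(self._pending) > n:
            done.extend(self._pending.popleft())
        return done


q = AsyncLoadGroups()
for stage in range(3):
    q.load(f"A{stage}")
    q.load(f"B{stage}")
    q.commit_group()

ready = q.wait_group(2)  # only stage 0 is guaranteed complete
print(ready)
```

Calling `wait_group(0)` afterwards drains everything that is still in flight, which is the usual choice at the end of a pipeline.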

Matrix Multiply (MMA)

| Function | Description | Arch |
|---|---|---|
| `tlx.async_dot(A, B, acc=None, use_acc=None, mBarriers=[], two_ctas=False)` | Warp-group MMA: `D = A @ B + C`. Maps to wgmma (Hopper) or tcgen05.mma (Blackwell) | Both |
| `tlx.async_dot_scaled(A, B, acc, A_scale, A_format, B_scale, B_format, ...)` | Scaled MMA with FP8 inputs: `D = (A * scale_A) @ (B * scale_B) + D` | Blackwell |
| `tlx.async_dot_wait(pendings, inp)` | Wait for N pending async dot operations to complete | Both |
| `tlx.tcgen05_commit(mBarrier, two_ctas=False)` | Make mbarrier track completion of prior tcgen05 ops. Use a SEPARATE mbarrier from `async_dot` | Blackwell |

Minimum tile sizes for async_dot: M ≥ 64, K ≥ 16, N ≥ 32
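
A host-side guard can reject tile shapes that fall below these minimums before a kernel is compiled. The helper below is a minimal sketch; `tile_violations` is a hypothetical name, and it checks only the three minimums quoted above, not any other shape constraints.

```python
MIN_TILE = {"M": 64, "N": 32, "K": 16}  # minimums quoted above


def tile_violations(m, n, k):
    """Return a message for each tile dimension below its minimum (empty if none)."""
    sizes = {"M": m, "N": n, "K": k}
    return [
        f"{dim}={sizes[dim]} is below the minimum {MIN_TILE[dim]}"
        for dim in ("M", "N", "K")
        if sizes[dim] < MIN_TILE[dim]
    ]


ok = tile_violations(128, 256, 64)
bad = tile_violations(32, 256, 8)
print(ok, bad)
```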

input_precision options: tf32, tf32x3, ieee

CLC (Cluster Launch Control) — Blackwell only

| Function | Description |
|---|---|
| `tlx.clc_create_context(num_consumers, num_stages=1)` | Create CLC pipeline context (allocates barriers + response buffers) |
| `tlx.clc_producer(context, p_producer, multi_ctas=False, k=0)` | Issue CLC try_cancel request from CTA 0 |
| `tlx.clc_consumer(context, p_consumer, multi_ctas=False, k=0)` | Decode tile ID from CLC response, signal completion. Returns `tile_id` or `-1` |

For 2-CTA mode: set multi_ctas=True (uses "arrive remote, wait local" pattern).
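
The `tile_id` or `-1` return value suggests a persistent loop that keeps working until no tile is left. The sketch below models only that control flow on the host: `run_persistent_cta` is a hypothetical function, and a `deque` stands in for the responses that `clc_producer` and `clc_consumer` would exchange through barriers on real hardware.

```python
from collections import deque


def run_persistent_cta(first_tile, pending_tiles, process):
    """Process `first_tile`, then keep taking tiles until the -1 sentinel appears."""
    tile_id = first_tile
    while tile_id != -1:
        process(tile_id)
        # Stand-in for a CLC response: the next tile, or -1 when none remain.
        tile_id = pending_tiles.popleft() if pending_tiles else -1


done = []
run_persistent_cta(0, deque([4, 7, 9]), done.append)
print(done)
```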

Utility

| Function | Description | Arch |
|---|---|---|
| `tlx.cluster_cta_rank()` | Unique CTA ID within a cluster (all dims) | Both |
| `tlx.thread_id(axis)` | Thread ID along axis 0, 1, or 2 | Both |
| `tlx.dtype_of(tensor_or_desc)` | Get element type of tensor or tensor descriptor | Both |
| `tlx.size_of(dtype)` | Size of dtype in bytes | Both |
| `tlx.get_fp8_format_name(dtype)` | Get FP8 format string (`"e5m2"` or `"e4m3"`) for scaled MMA | Both |
| `tlx.clock64()` | 64-bit hardware clock value (for timing) | Both |
| `tlx.stoch_round(src, dst_ty, rand_bits)` | Hardware stochastic rounding FP32 → FP8/BF16/F16 | Blackwell |

Common patterns

Producer-consumer with mbarrier (pipelined GEMM)

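```python
# Kernel fragment: num_stages, num_consumers, stage, phase, nbytes, desc,
# indices, A, B, and acc are assumed to be defined by the surrounding kernel.
bars_full = tlx.alloc_barriers(num_stages, arrive_count=1)   # TMA arrives implicitly
bars_empty = tlx.alloc_barriers(num_stages, arrive_count=num_consumers)

# Select the barriers for the current pipeline stage
bar_full = tlx.local_view(bars_full, stage)
bar_empty = tlx.local_view(bars_empty, stage)

# Producer: TMA load → signal full
tlx.barrier_expect_bytes(bar_full, nbytes)
tlx.async_descriptor_load(desc, indices, barrier=bar_full)

# Consumer: wait full → MMA → signal empty
tlx.barrier_wait(bar_full, phase)
tlx.async_dot(A, B, acc)
tlx.barrier_arrive(bar_empty)
```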
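
The `phase` argument in the pattern above has to change every time the ring of stages wraps around. One common convention, sketched here under the assumption of a simple round-robin ring, derives both the stage index and the phase bit from the loop iteration. `stage_and_phase` is a hypothetical helper, not a TLX function.

```python
def stage_and_phase(iteration, num_stages):
    """Map a loop iteration to (stage index, phase bit) for a ring of barriers."""
    stage = iteration % num_stages
    phase = (iteration // num_stages) & 1  # flips each time the ring wraps
    return stage, phase


schedule = [stage_and_phase(i, num_stages=3) for i in range(7)]
print(schedule)
# [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 0)]
```

A producer waiting on the "empty" barriers often uses the opposite bit (`phase ^ 1`) so that its first pass through the ring does not block. Treat that as a convention to verify against the barrier documentation listed under Deep-dive docs.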

PingPong with named barriers

```python
# Consumer 0 waits for Consumer 1, then issues MMA
tlx.named_barrier_wait(9, 256)   # 256 = 2 warp groups * 4 warps * 32 threads
qk = tlx.async_dot(q, k)
tlx.named_barrier_arrive(10, 256)

# Consumer 1 waits for Consumer 0's MMA to finish
tlx.named_barrier_arrive(9, 256)
tlx.named_barrier_wait(10, 256)
qk = tlx.async_dot(q, k)
```
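
The ordering this handshake enforces can be reproduced on the CPU. The sketch below uses two Python threads and two semaphores as stand-ins for barriers 9 and 10. It is a simplification: a real named barrier counts arriving threads (both the waiting and the arriving warp group contribute to the 256), whereas a semaphore passes a single token. The `pingpong` function is a hypothetical illustration.

```python
import threading


def pingpong(iterations):
    """Run two consumers that alternate access to a shared 'tensor core'."""
    bar9 = threading.Semaphore(0)    # stand-in for named barrier 9
    bar10 = threading.Semaphore(0)   # stand-in for named barrier 10
    log = []

    def consumer0():
        for i in range(iterations):
            bar9.acquire()           # named_barrier_wait(9, ...)
            log.append(("c0", i))    # issue MMA
            bar10.release()          # named_barrier_arrive(10, ...)

    def consumer1():
        for i in range(iterations):
            bar9.release()           # named_barrier_arrive(9, ...)
            bar10.acquire()          # named_barrier_wait(10, ...)
            log.append(("c1", i))    # issue MMA

    threads = [threading.Thread(target=consumer0),
               threading.Thread(target=consumer1)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


order = pingpong(3)
print(order)
```

The log always alternates `c0`, `c1`, `c0`, `c1`, and so on, because each consumer can only issue its MMA after the other has signalled.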

Deep-dive docs

  • API reference: third_party/tlx/README.md
  • Barriers: third_party/tlx/doc/tlx_barriers.md
  • Placeholder layouts: third_party/tlx/doc/PlaceholderLayouts.md
  • Storage alias design: third_party/tlx/doc/storage_alias_spec_design.md