System case study / 2026
FlashAttention-2 CUDA Kernel
Custom IO-aware GPU attention engine
active
CUDA C++ / PyTorch C++ Extension / Triton / NVIDIA Nsight Compute
CUDA / Attention / Kernel Engineering / Inference Optimization
Throughput
2.1x
Versus PyTorch SDPA on an A100 at sequence length 4096.
Tile
64x64
Q block and K/V block shape.
Memory
O(N)
K/V tiles are streamed through on-chip memory, so the N x N attention matrix is never stored and extra memory grows linearly with sequence length N.
Motivation
Rebuild attention from the memory hierarchy upward and understand exactly where framework kernels spend bandwidth, registers, shared memory, and occupancy.
Design Constraints
- Preserve numerical agreement against PyTorch reference paths.
- Reduce HBM reads and writes by keeping tiled QK, online softmax, and PV accumulation SRAM-resident.
- Support causal masking and the sequence-length-4096 benchmark setting.
- Package as a standalone PyTorch C++ extension usable from Python.
System Architecture
- Thread-block tiling with Br=64 (query rows per Q tile) and Bc=64 (key/value rows per K/V tile).
- Register-level running max, denominator, and output accumulators for online softmax.
- Fused QK score calculation, causal mask application, exponent rescaling, and V accumulation.
- Triton cross-check kernel used as an implementation sanity path.
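The running max, denominator, and output accumulator described above can be sketched in plain Python. This is a CPU reference for the math only, not the CUDA kernel from this project: the names `naive_attention`, `tiled_attention`, `block_c`, and `max_abs_diff` are illustrative assumptions, and the sketch loops over single query rows rather than Br-row tiles, which does not change the arithmetic.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_attention(Q, K, V, causal=True):
    """Reference path: builds each full score row before normalizing."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for i, q in enumerate(Q):
        last = i if causal else len(K) - 1
        scores = [scale * dot(q, K[j]) for j in range(last + 1)]
        row_max = max(scores)
        weights = [math.exp(s - row_max) for s in scores]
        denom = sum(weights)
        out.append([sum(w * V[j][c] for j, w in enumerate(weights)) / denom
                    for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_c=4, causal=True):
    """Streams K/V in blocks of block_c keys; keeps only (m, l, acc) per row."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for i, q in enumerate(Q):
        m = -math.inf            # running row max
        l = 0.0                  # running softmax denominator
        acc = [0.0] * len(V[0])  # running unnormalized output
        for start in range(0, len(K), block_c):
            if causal and start > i:
                break  # this block and all later ones are fully masked
            stop = min(start + block_c, len(K))
            cols = [j for j in range(start, stop) if not causal or j <= i]
            scores = [scale * dot(q, K[j]) for j in cols]
            m_new = max(m, max(scores))
            alpha = math.exp(m - m_new)  # rescales old state; 0.0 on first block
            p = [math.exp(s - m_new) for s in scores]
            l = alpha * l + sum(p)
            acc = [alpha * a + sum(w * V[j][c] for w, j in zip(p, cols))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
    return out


def max_abs_diff(A, B):
    return max(abs(x - y) for ra, rb in zip(A, B) for x, y in zip(ra, rb))


# Small random problem (sizes are arbitrary, chosen so the check runs fast).
rng = random.Random(0)
n, d = 10, 8
Q, K, V = ([[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
           for _ in range(3))
print(max_abs_diff(tiled_attention(Q, K, V), naive_attention(Q, K, V)))
```

The rescaling factor `alpha` is the step that is easy to get wrong: every time the running max changes, both the denominator and the accumulator must be scaled by the same factor before the new block is added.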
Performance Bottlenecks
- Shared memory pressure from tile staging.
- Register pressure during online softmax state updates.
- Warp-level divergence at causal boundaries.
- HBM bandwidth consumed by naive O(N^2) attention materialization.
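To see where causal-boundary divergence can arise, it helps to count which tiles actually need per-element masking. The sketch below is an illustration under the stated 64x64 tiling and sequence length 4096; the function name `classify_causal_tiles` and the three category labels are assumptions made for this example, not names from the kernel.

```python
def classify_causal_tiles(seq_len, block_r=64, block_c=64):
    """Count tiles by how the causal mask (key index <= query index) hits them."""
    counts = {"full": 0, "partial": 0, "skipped": 0}
    for r0 in range(0, seq_len, block_r):
        r1 = min(r0 + block_r, seq_len) - 1      # last query row in the tile
        for c0 in range(0, seq_len, block_c):
            c1 = min(c0 + block_c, seq_len) - 1  # last key column in the tile
            if c0 > r1:
                counts["skipped"] += 1   # every key lies in the future: no work
            elif c1 <= r0:
                counts["full"] += 1      # every key is visible: no mask needed
            else:
                counts["partial"] += 1   # diagonal tile: mask per element
    return counts


print(classify_causal_tiles(4096))
# {'full': 2016, 'partial': 64, 'skipped': 2016}
```

Under these assumptions only the 64 diagonal tiles mix masked and unmasked elements, so that is where threads in a warp can take different paths; the skipped tiles can be dropped from the K/V loop entirely.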
Optimization Decisions
- Avoid materializing the attention matrix.
- Use online softmax to maintain stable row-wise normalization across K/V blocks.
- Tune block sizes against occupancy instead of maximizing tile size blindly.
- Profile memory throughput, achieved occupancy, and register spills in Nsight Compute.
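A rough budget calculation shows why larger tiles are not automatically better and what skipping materialization saves. This is a back-of-the-envelope sketch, not a measurement: the head dimension of 64, the 164 KB shared-memory figure for an A100 SM, and the assumption that exactly one Q tile, one K tile, and one V tile are staged per block are all assumptions of this example. Registers and other per-block limits are ignored.

```python
def tile_smem_bytes(block_r, block_c, head_dim, bytes_per_elem):
    """Bytes to stage one Q tile plus one K tile and one V tile."""
    return (block_r + 2 * block_c) * head_dim * bytes_per_elem


def blocks_per_sm_by_smem(smem_per_block, smem_per_sm):
    """Upper bound on resident blocks per SM from shared memory alone."""
    return smem_per_sm // smem_per_block


def score_matrix_bytes(seq_len, bytes_per_elem):
    """Size of the full N x N score matrix for one head."""
    return seq_len * seq_len * bytes_per_elem


SMEM_PER_SM = 164 * 1024  # assumption: A100 figure; verify on the target device
HEAD_DIM = 64             # assumption: not stated in the write-up
FP32 = 4

small = tile_smem_bytes(64, 64, HEAD_DIM, FP32)
large = tile_smem_bytes(128, 128, HEAD_DIM, FP32)
print(small, blocks_per_sm_by_smem(small, SMEM_PER_SM))  # 49152 3
print(large, blocks_per_sm_by_smem(large, SMEM_PER_SM))  # 98304 1
print(score_matrix_bytes(4096, FP32) // 2**20, "MiB")    # 64 MiB
```

Under these assumptions, doubling both tile dimensions drops the shared-memory bound from three resident blocks per SM to one, which is the kind of trade-off the occupancy profiling is meant to catch. The last line is the per-head score matrix that tiled streaming never writes to HBM.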
Benchmark Methodology
- Compared against torch.nn.functional.scaled_dot_product_attention (PyTorch SDPA).
- Benchmarked on A100 at sequence length 4096.
- Validated float32 outputs against PyTorch reference outputs.
- Ran warmup iterations, timed with CUDA synchronization, and reported median throughput.
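The timing procedure above can be sketched as a small harness. This is a generic outline rather than the project's benchmark script: the function name `benchmark`, the `sync` parameter, and the default iteration counts are assumptions made for this example.

```python
import statistics
import time


def benchmark(fn, sync=lambda: None, warmup=10, iters=50):
    """Return the median wall-clock seconds per call of fn.

    sync is called before and after each timed call so that asynchronous
    work (for example, queued GPU kernels) is included in the measurement.
    """
    for _ in range(warmup):
        fn()  # warmup: caches, lazy initialization, clock ramp-up
    sync()
    samples = []
    for _ in range(iters):
        sync()
        start = time.perf_counter()
        fn()
        sync()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# CPU-only usage so the sketch runs anywhere.
median_s = benchmark(lambda: sum(range(10_000)))
print(f"median: {median_s * 1e6:.1f} us")
```

For a GPU kernel, `sync` would be `torch.cuda.synchronize`, since kernel launches return before the work finishes; timing with `torch.cuda.Event(enable_timing=True)` is a common alternative. A throughput ratio like the one reported here is then the reference median divided by the custom-kernel median.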
Results
- Reached 2.1x throughput over PyTorch SDPA in the profiled A100 setting.
- Reduced HBM traffic by replacing O(N^2) attention-matrix materialization with tiled streaming attention.
- Produced an installable PyTorch extension and reproducible kernel write-up.
Lessons Learned
- The core speedup comes not from fusion alone but from avoiding global-memory traffic.
- Online softmax correctness is the fragile center of the implementation.
- Register pressure can erase theoretical tiling wins unless profiled early.