Mani Pal

Engineer-researcher


LLM systems, CUDA kernels, inference optimization, compression, interpretability, and distributed AI infrastructure.

System case study / 2026

FlashAttention-2 CUDA Kernel

Custom IO-aware GPU attention engine

Status: active
Stack: CUDA C++ / PyTorch C++ Extension / Triton / NVIDIA Nsight Compute
Tags: CUDA / Attention / Kernel Engineering / Inference Optimization

Throughput

2.1x

Versus PyTorch SDPA on an A100 at sequence length 4096.

Tile

64x64

Rows per Q block (Br) and per K/V block (Bc).

Memory

O(N)

Extra memory grows linearly with sequence length: tiles are streamed, so the N x N attention matrix is never stored.
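To make the O(N) claim concrete, the sketch below counts bytes per (batch, head) pair for float32 at N = 4096. It counts only the attention-specific state: the N x N score matrix on the naive path versus the two per-row running statistics (max m and denominator l) on the streaming path. Q, K, V, and O are N x d on both paths and are left out. The helper name is hypothetical.

```python
def attention_memory_bytes(seq_len, bytes_per_elem=4):
    """Attention-specific memory per (batch, head): naive vs. streaming.

    Assumption: the streaming path keeps only a running max and a running
    denominator per query row; Q, K, V and the output are excluded because
    they are N x d on both paths.
    """
    materialized = seq_len * seq_len * bytes_per_elem  # full N x N score matrix
    streaming = 2 * seq_len * bytes_per_elem           # m and l for each row
    return materialized, streaming


full, stream = attention_memory_bytes(4096)
print(full / 2**20, "MiB vs", stream / 2**10, "KiB")
```

At N = 4096 this prints 64.0 MiB versus 32.0 KiB per head, a 2048x reduction in attention-specific state, and the gap widens quadratically with sequence length.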

Motivation

Rebuild attention from the memory hierarchy upward to understand exactly how framework kernels spend HBM bandwidth, registers, and shared memory, and what those choices cost in occupancy.

Design Constraints

  • Preserve numerical agreement against PyTorch reference paths.
  • Reduce HBM reads and writes by keeping the tiled QK^T scores, online-softmax state, and PV accumulation on chip (shared memory and registers).
  • Support causal masking and the sequence-length-4096 benchmark setting.
  • Package as a standalone PyTorch C++ extension usable from Python.
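A CPU stand-in for the reference path makes the "numerical agreement" and "causal masking" constraints testable: a naive causal attention that builds the full N x N score matrix, plus a max-abs-error check. In the project the reference is PyTorch; this pure-Python version only makes the logic visible. Inputs are N x d lists of rows, the function names are hypothetical, and the 1e-5 tolerance is an assumption, not the project's published threshold.

```python
import math


def reference_attention(q, k, v, causal=True):
    """Naive attention that builds the full N x N score matrix (the O(N^2) path)."""
    n, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    # Full score matrix; future positions get -inf under the causal mask.
    scores = [[-math.inf if causal and j > i
               else scale * sum(a * b for a, b in zip(q[i], k[j]))
               for j in range(n)] for i in range(n)]
    out = []
    for row in scores:
        mx = max(row)                           # subtract row max for stability
        w = [math.exp(s - mx) for s in row]     # masked entries become 0.0
        z = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(n)) / z for c in range(d)])
    return out


def max_abs_error(a, b):
    """Largest elementwise |a - b| over two N x d outputs."""
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


def agrees(candidate, reference, atol=1e-5):
    """Pass/fail check; atol=1e-5 is an assumed float32 tolerance."""
    return max_abs_error(candidate, reference) <= atol
```

With real tensors, `torch.testing.assert_close(actual, expected, rtol=..., atol=...)` performs the same kind of comparison using both relative and absolute tolerances.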

System Architecture

  • Thread-block tiling with Br = 64 query rows and Bc = 64 key/value rows per tile.
  • Register-level running max, denominator, and output accumulators for online softmax.
  • Fused QK^T score computation, causal mask application, exponent rescaling, and V accumulation.
  • A Triton kernel serves as an independent cross-check of the CUDA implementation.
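The architecture above can be mirrored on the CPU to check the algorithm before debugging the CUDA version. This pure-Python sketch assumes row-major N x d lists and the standard 1/sqrt(d) scale. Its outer loop stands in for one thread block per Q tile, and the inner loop streams K/V tiles while keeping only the per-row running max m, denominator l, and unnormalized accumulator. It is not the project's kernel, and the function name is hypothetical.

```python
import math


def tiled_attention(q, k, v, br=64, bc=64, causal=True):
    """Attention over N x d lists using Q/K/V tiling and online softmax.

    The outer loop plays the role of one thread block per Q tile; the inner
    loop streams K/V tiles. No N x N score matrix is ever built.
    """
    n, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = [[0.0] * d for _ in range(n)]
    for i0 in range(0, n, br):
        rows = range(i0, min(i0 + br, n))
        m = {i: -math.inf for i in rows}        # running row max
        l = {i: 0.0 for i in rows}              # running softmax denominator
        acc = {i: [0.0] * d for i in rows}      # unnormalized output accumulator
        for j0 in range(0, n, bc):
            if causal and j0 > rows[-1]:
                break                           # this and later K/V tiles are fully masked
            cols = range(j0, min(j0 + bc, n))
            for i in rows:
                # Fused QK^T scores for this tile with the causal mask applied in place.
                s = [-math.inf if causal and j > i
                     else scale * sum(a * b for a, b in zip(q[i], k[j]))
                     for j in cols]
                m_new = max(m[i], max(s))
                alpha = math.exp(m[i] - m_new)          # rescales the old state
                p = [math.exp(x - m_new) for x in s]    # masked slots give 0.0
                l[i] = alpha * l[i] + sum(p)
                acc[i] = [alpha * acc[i][c] + sum(pj * v[j][c] for pj, j in zip(p, cols))
                          for c in range(d)]
                m[i] = m_new
        for i in rows:
            out[i] = [x / l[i] for x in acc[i]]         # normalize once at the end
    return out
```

Because the result must not depend on the tiling, running it with small odd tile sizes (say br=4, bc=3) and with a single full-size tile is a cheap test of the rescaling logic.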

Performance Bottlenecks

  • Shared memory pressure from tile staging.
  • Register pressure during online softmax state updates.
  • Warp-level divergence at causal boundaries.
  • HBM bandwidth consumed by naive O(N^2) attention-matrix materialization.
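Causal-boundary divergence is confined to a small set of tiles, which a tile-pair count for one head makes visible. The three-way split used here (skip, unmasked, diagonal) is an illustrative assumption about how a kernel can branch per tile, not a description of this kernel's exact control flow; the function name is hypothetical.

```python
def classify_tiles(seq_len, br=64, bc=64):
    """Count (Q tile, K/V tile) pairs by how causal masking affects them."""
    counts = {"skip": 0, "full": 0, "diagonal": 0}
    for i0 in range(0, seq_len, br):
        i1 = min(i0 + br, seq_len) - 1          # last query row in the tile
        for j0 in range(0, seq_len, bc):
            j1 = min(j0 + bc, seq_len) - 1      # last key column in the tile
            if j0 > i1:
                counts["skip"] += 1             # every key is in the future: no work
            elif j1 <= i0:
                counts["full"] += 1             # every key is visible: no mask needed
            else:
                counts["diagonal"] += 1         # straddles the diagonal: per-element mask
    return counts


print(classify_tiles(4096))
```

For N = 4096 with 64x64 tiles this gives 2016 skipped, 2016 unmasked, and 64 diagonal tile pairs, so per-element masking, and whatever divergence it causes, touches only about 1.6% of tile pairs.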

Optimization Decisions

  • Avoid materializing the attention matrix.
  • Use online softmax to maintain stable row-wise normalization across K/V blocks.
  • Tune block sizes against occupancy instead of maximizing tile size blindly.
  • Profile memory throughput, achieved occupancy, and register spills in Nsight Compute.
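Tuning block sizes against occupancy can start from a back-of-envelope estimate before profiling. The sketch assumes float32 Q, K, and V tiles staged in shared memory (Br x d plus two Bc x d), and every input besides Br and Bc is a hypothetical choice: head dimension 64, 128 threads per block, and 128 registers per thread. The defaults are A100 (compute capability 8.0) per-SM limits: 164 KB shared memory, 65,536 registers, 2048 threads, 32 blocks. It ignores allocation granularity and per-block reservations, so Nsight Compute's occupancy section remains the authority.

```python
def estimate_occupancy(br, bc, head_dim, threads_per_block, regs_per_thread,
                       bytes_per_elem=4, smem_per_sm=164 * 1024,
                       regs_per_sm=65536, max_threads_per_sm=2048,
                       max_blocks_per_sm=32):
    """Rough resident blocks per SM and occupancy (active threads / max threads).

    Simplification: ignores register and shared-memory allocation granularity.
    """
    # Q tile (br x d) + K tile (bc x d) + V tile (bc x d) staged in shared memory.
    smem_per_block = (br + 2 * bc) * head_dim * bytes_per_elem
    limits = {
        "smem": smem_per_sm // smem_per_block,
        "regs": regs_per_sm // (regs_per_thread * threads_per_block),
        "threads": max_threads_per_sm // threads_per_block,
        "blocks": max_blocks_per_sm,
    }
    blocks = min(limits.values())
    occupancy = blocks * threads_per_block / max_threads_per_sm
    return blocks, occupancy, limits


blocks, occ, limits = estimate_occupancy(br=64, bc=64, head_dim=64,
                                         threads_per_block=128, regs_per_thread=128)
print(blocks, occ, min(limits, key=limits.get))
```

Under these assumptions shared memory is the binding limit: 3 blocks per SM, or 18.75% occupancy. Doubling Br and Bc to 128 drops it to 1 block per SM, which is why a larger tile is not automatically faster. Note that CUDA kernels using more than 48 KB of dynamic shared memory per block must opt in via `cudaFuncSetAttribute` with `cudaFuncAttributeMaxDynamicSharedMemorySize`.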

Benchmark Methodology

  • Compared against torch.nn.functional.scaled_dot_product_attention (PyTorch SDPA).
  • Benchmarked on an A100 at sequence length 4096.
  • Validated float32 outputs for numerical agreement with PyTorch reference outputs.
  • Ran warmup iterations, timed with explicit CUDA synchronization, and reported median throughput.
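The methodology above (warmup, synchronized timing, median) fits in a small harness. This sketch takes a `sync` callable so it runs anywhere; on a GPU one would pass `torch.cuda.synchronize`. The function names are hypothetical, and the FLOP count uses the common 4·N²·d forward estimate (two matmuls, two FLOPs per multiply-add), halved for causal masking as an approximation.

```python
import statistics
import time


def benchmark(fn, sync=lambda: None, warmup=10, iters=50):
    """Median wall-clock seconds per call, synchronizing around each timed call."""
    for _ in range(warmup):           # warm caches, allocator, JIT/autotuning
        fn()
    sync()                            # drain warmup work before timing starts
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()                        # wait for queued GPU work before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def attention_tflops(batch, heads, seq_len, head_dim, seconds, causal=False):
    """Forward-pass TFLOP/s from the 4 * N^2 * d estimate per (batch, head)."""
    flops = 4 * batch * heads * seq_len * seq_len * head_dim
    if causal:
        flops //= 2                   # roughly half the score matrix is computed
    return flops / seconds / 1e12
```

A GPU call might look like `benchmark(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True), sync=torch.cuda.synchronize)`, with `F` being `torch.nn.functional`. Synchronizing inside the timed region includes host launch overhead; CUDA events (`torch.cuda.Event(enable_timing=True)`) give tighter per-kernel timings.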

Results

  • Reached 2.1x the throughput of PyTorch SDPA in the profiled A100 setting.
  • Cut HBM traffic by replacing O(N^2) attention-matrix materialization with tiled streaming attention.
  • Produced an installable PyTorch extension and reproducible kernel write-up.

Lessons Learned

  • The core speedup is not only fusion; it is avoiding global-memory traffic.
  • Online softmax correctness is the fragile center of the implementation.
  • Register pressure can erase theoretical tiling wins unless profiled early.
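The fragile center is the online-softmax recurrence itself. For a block of query rows $Q_i$ and K/V tiles $j = 1, \dots, T$, the standard FlashAttention-2 update with deferred normalization is shown below; the notation is assumed here, not taken from the write-up. Every tile must rescale both $\ell$ and $O$ by the same factor, masked scores enter as $-\infty$ so their weights vanish, and division by $\ell$ happens exactly once at the end.

```latex
\begin{aligned}
S^{(j)} &= \tfrac{1}{\sqrt{d}}\, Q_i K_j^{\top} \\
m^{(j)} &= \max\!\big(m^{(j-1)},\ \operatorname{rowmax}(S^{(j)})\big) \\
\tilde{P}^{(j)} &= \exp\!\big(S^{(j)} - m^{(j)}\big) \\
\ell^{(j)} &= e^{\,m^{(j-1)} - m^{(j)}}\, \ell^{(j-1)} + \operatorname{rowsum}\big(\tilde{P}^{(j)}\big) \\
O^{(j)} &= e^{\,m^{(j-1)} - m^{(j)}}\, O^{(j-1)} + \tilde{P}^{(j)} V_j \\[4pt]
m^{(0)} &= -\infty, \quad \ell^{(0)} = 0, \quad O^{(0)} = 0, \qquad O_i = O^{(T)} / \ell^{(T)}
\end{aligned}
```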