From 40755feac94bab62a4a3eaec01b83bee94d7200c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E8=B6=85?= Date: Tue, 23 Dec 2025 20:00:47 +0800 Subject: [PATCH 1/4] flashattention-2 fwd and bwd support (not support causal and gqa) --- test/bench_attention.py | 96 ++++++++++----- test/kernels/attention.py | 245 +++++++++++++++++++++++++++++++++++--- 2 files changed, 297 insertions(+), 44 deletions(-) diff --git a/test/bench_attention.py b/test/bench_attention.py index 9350be1..31652e2 100644 --- a/test/bench_attention.py +++ b/test/bench_attention.py @@ -12,8 +12,8 @@ import torch from util import estimate_bench_iter -from kernels.attention import fmha_kernel - +from kernels.attention import fmha_kernel, fmha_bwd_dq_kernel, fmha_bwd_dk_dv_kernel, fmha_bwd_preprocess_kernel +has_backward = True def qkv_id(qkv_shape: tuple[tuple[int, ...], tuple[int, ...]]) -> str: q_shape, kv_shape = qkv_shape @@ -29,21 +29,22 @@ def qkv_id(qkv_shape: tuple[tuple[int, ...], tuple[int, ...]]) -> str: @pytest.fixture( params=[ # B, H, L, D - ((1, 32, 1024, 128), (1, 32, 1024, 128)), # prefill - ((1, 32, 1024, 128), (1, 8, 1024, 128)), # prefill + gqa - ((1, 32, 8192, 128), (1, 32, 8192, 128)), # prefill - ((1, 32, 8192, 128), (1, 8, 8192, 128)), # prefill + gqa - ((1, 32, 1, 128), (1, 32, 1024, 128)), # decode - ((8, 32, 1, 128), (8, 32, 1024, 128)), # decode - ((1, 32, 1, 128), (1, 8, 1024, 128)), # decode + gqa - ((8, 32, 1, 128), (8, 8, 1024, 128)), # decode + gqa + ((1, 32, 1024, 128), (1, 32, 1024, 128)), # prefill + # ((1, 32, 1024, 128), (1, 8, 1024, 128)), # prefill + gqa + ((1, 32, 8192, 64), (1, 32, 8192, 64)), + ((1, 32, 8192, 128), (1, 32, 8192, 128)), +# ((1, 32, 8192, 128), (1, 8, 8192, 128)), # prefill + gqa +# ((1, 32, 1, 128), (1, 32, 1024, 128)), # decode +# ((8, 32, 1, 128), (8, 32, 1024, 128)), # decode +# ((1, 32, 1, 128), (1, 8, 1024, 128)), # decode + gqa +# ((8, 32, 1, 128), (8, 8, 1024, 128)), # decode + gqa ], ids=qkv_id) def qkv_shape(request): return request.param -@pytest.fixture(params=[torch.float16, torch.bfloat16], ids=dtype_id) +@pytest.fixture(params=[torch.bfloat16], ids=dtype_id) def dtype(request): return request.param @@ -52,25 +53,37 @@ def dtype(request): def bench_fmha(qkv_shape, dtype, backend, benchmark): q_shape, kv_shape = qkv_shape - q = torch.randn(q_shape, dtype=dtype, device='cuda') - k = torch.randn(kv_shape, dtype=dtype, device='cuda') - v = torch.randn(kv_shape, dtype=dtype, device='cuda') - o = torch.empty_like(q) - ref = torch.empty_like(q) + q = torch.randn(q_shape, dtype=dtype, device='cuda', requires_grad=True) + k = torch.randn(kv_shape, dtype=dtype, device='cuda', requires_grad=True) + v = torch.randn(kv_shape, dtype=dtype, device='cuda', requires_grad=True) + lse = torch.randn(q_shape[:-1], dtype=torch.float32, device='cuda') + grad = torch.randn(q_shape, dtype=dtype, device='cuda', requires_grad=True) + o = torch.zeros_like(q) + ref = torch.zeros_like(q) is_causal = q_shape[2] == kv_shape[2] + is_causal = False enable_gqa = q_shape[1] != kv_shape[1] - backend(q, k, v, o, is_causal, enable_gqa) - ref_fmha(q, k, v, ref, is_causal, enable_gqa) + if has_backward: + dq, dk, dv = backend(q, k, v, o, grad, lse, is_causal, enable_gqa) + dq_ref, dk_ref, dv_ref = ref_fmha(q, k, v, ref, grad, is_causal, enable_gqa) + torch.testing.assert_close(o, ref, atol=1e-2, rtol=5e-2) + torch.testing.assert_close(dq, dq_ref, atol=1e-2, rtol=5e-2) + torch.testing.assert_close(dk, dk_ref, atol=1e-2, rtol=5e-2) + torch.testing.assert_close(dv, dv_ref, atol=1e-2, rtol=5e-2) + else: + backend(q, k, v, o, grad, lse, is_causal, enable_gqa) + ref_fmha(q, k, v, ref, grad, is_causal, enable_gqa) + torch.testing.assert_close(o, ref, atol=1e-2, rtol=5e-2) torch.cuda.synchronize() warmup_rounds, iterations, rounds = estimate_bench_iter( - backend, (q, k, v, o, is_causal, enable_gqa), + backend, (q, k, v, o, grad, lse, is_causal, enable_gqa), ) benchmark.pedantic( - backend, (q, k, v, o, is_causal, enable_gqa), + backend, (q, k, v, o, grad, lse, is_causal, enable_gqa), rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations, ) @@ -87,7 +100,7 @@ def bench_fmha(qkv_shape, dtype, backend, benchmark): benchmark.extra_info['bytes_rw'] = bytes_rw -def cutile_fmha(q, k, v, o, is_causal, enable_gqa): +def cutile_fmha(q, k, v, o, grad, lse, is_causal, enable_gqa): b, qh, q_len, d = q.shape _, kh, k_len, _ = k.shape qk_scale = 1 / sqrt(d) @@ -97,15 +110,30 @@ def cutile_fmha(q, k, v, o, is_causal, enable_gqa): input_pos = 0 if q_len == k_len else (k_len - 1) EVEN_K = (k_len % TILE_N) == 0 ct.launch(torch.cuda.current_stream(), grid, fmha_kernel, - (q, k, v, o, - qk_scale, - input_pos, - d, qh, - TILE_M, TILE_N, + (q, k, v, o, lse, + qk_scale, input_pos, d, qh, TILE_M, TILE_N, query_group_size, is_causal, EVEN_K)) - - -def torch_fmha(q, k, v, o, is_causal, enable_gqa): + if has_backward: + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + delta = torch.zeros_like(lse) + ct.launch(torch.cuda.current_stream(), (ceil(q_len / TILE_M), b * qh, 1), + fmha_bwd_preprocess_kernel, + (o, grad, delta, qh, TILE_M, d)) + ct.launch(torch.cuda.current_stream(), (ceil(k_len / TILE_N), b * qh, 1), + fmha_bwd_dk_dv_kernel, + (q, k, v, grad, delta, lse, dk, dv, + qk_scale, input_pos, d, qh, TILE_M, TILE_N, + query_group_size, is_causal, EVEN_K)) + ct.launch(torch.cuda.current_stream(), (ceil(q_len / TILE_M), b * qh, 1), + fmha_bwd_dq_kernel, + (q, k, v, grad, delta, lse, dq, + qk_scale, input_pos, d, qh, TILE_M, TILE_N, + query_group_size, is_causal, EVEN_K)) + return dq, dk, dv + +def torch_fmha(q, k, v, o, grad, lse, is_causal, enable_gqa): backend = SDPBackend.CUDNN_ATTENTION \ if (q.shape[2] == k.shape[2]) \ else SDPBackend.FLASH_ATTENTION @@ -114,11 +142,19 @@ def torch_fmha(q, k, v, o, is_causal, enable_gqa): is_causal=is_causal, enable_gqa=enable_gqa) o.copy_(ret) + if has_backward: + ret.backward(grad, retain_graph=True) + dq, dk, dv = q.grad, k.grad, v.grad + return dq, dk, dv -def ref_fmha(q, k, v, o, is_causal, enable_gqa): +def ref_fmha(q, k, v, o, grad, is_causal, enable_gqa): with sdpa_kernel(SDPBackend.FLASH_ATTENTION): ret = scaled_dot_product_attention(q, k, v, is_causal=is_causal, enable_gqa=enable_gqa) o.copy_(ret) + if has_backward: + ret.backward(grad, retain_graph=True) + dq, dk, dv = q.grad, k.grad, v.grad + return dq, dk, dv \ No newline at end of file diff --git a/test/kernels/attention.py b/test/kernels/attention.py index fd431b8..f76b2d4 100644 --- a/test/kernels/attention.py +++ b/test/kernels/attention.py @@ -7,6 +7,7 @@ import math from cuda.tile import RoundingMode as RMd +from cuda.tile import kernel, ByTarget INV_LOG_2 = 1.0 / math.log(2) @@ -14,20 +15,13 @@ # Define type aliases for Constant integers and booleans ConstInt = ct.Constant[int] ConstBool = ct.Constant[bool] - +allow_tma = True # --- FMHA Kernel Implementation --- @ct.kernel(occupancy=2) -def fmha_kernel(Q, K, V, Out, - qk_scale: float, - input_pos: int, - TILE_D: ConstInt, # TILE_D = hidden_size - H: ConstInt, - TILE_M: ConstInt, - TILE_N: ConstInt, - QUERY_GROUP_SIZE: ConstInt, - CAUSAL: ConstBool, - EVEN_K: ConstBool): +def fmha_kernel(Q, K, V, Out, Lse, qk_scale: float, input_pos: int, TILE_D: ConstInt, + H: ConstInt, TILE_M: ConstInt, TILE_N: ConstInt, QUERY_GROUP_SIZE: ConstInt, + CAUSAL: ConstBool, EVEN_K: ConstBool): """ cuTile kernel for Fused Multi-Head Attention (FMHA). Computes attention output for a specific batch item and head, using tiling and online softmax. @@ -55,10 +49,10 @@ def fmha_kernel(Q, K, V, Out, m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32) l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32) acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32) - + lse = ct.full((TILE_M, 1), 0.0, dtype=np.float32) # Load query tile for this batch, head, and M-chunk q = ct.load( - Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D) + Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D), allow_tma=allow_tma ).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] # loop over k, v and update accumulator @@ -81,6 +75,7 @@ def fmha_kernel(Q, K, V, Out, K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N), order=(0, 1, 3, 2), latency=2, + allow_tma=allow_tma ) k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N] qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) @@ -89,7 +84,7 @@ def fmha_kernel(Q, K, V, Out, # --- Apply Causal Masking --- if (CAUSAL or not EVEN_K) and j >= mask_start: offs_n = j * TILE_N + offs_n_tile - mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool) + mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool_) # out of bound mask if not EVEN_K: mask = mask & (offs_n < k_seqlen) @@ -117,12 +112,234 @@ def fmha_kernel(Q, K, V, Out, v = ct.load( V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D), latency=4, + allow_tma=allow_tma ).reshape((TILE_N, TILE_D)) # [TILE_N, TILE_D] p = p.astype(Q.dtype) acc = ct.mma(p, v, acc) # [TILE_M, TILE_N] m_i = m_ij # [TILE_M, 1] # --- Final Normalization and Store --- + lse = m_i + ct.log2(l_i) + lse = lse.reshape((1, 1, TILE_M)) acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX) acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype) + ct.store(Lse, index=(batch_idx, head_idx, bid_x), tile=lse) ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc) + + +@ct.kernel(occupancy=2) +def fmha_bwd_dq_kernel(Q, K, V, Grad, Delta, Lse, DQ, qk_scale: float, input_pos: int, TILE_D: ConstInt, + H: ConstInt, TILE_M: ConstInt, TILE_N: ConstInt, QUERY_GROUP_SIZE: ConstInt, + CAUSAL: ConstBool, EVEN_K: ConstBool): + # Map block IDs to batch and head indices + bid_x = ct.bid(0) + bid_y = ct.bid(1) + batch_idx = bid_y // H + head_idx = bid_y % H + off_kv_h = head_idx // QUERY_GROUP_SIZE + + # Initialize offsets for current query tile (M-dimension) + offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=np.int32) # [TILE_M] + offs_m += input_pos + offs_m = offs_m[:, None] # [TILE_M, 1] + + # Initialize local offsets for key/value tile (N-dimension) + offs_n_tile = ct.arange(TILE_N, dtype=np.int32) # [TILE_N] + offs_n_tile = offs_n_tile[None, :] # [1, TILE_N] + + # Load query tile for this batch, head, and M-chunk + q = ct.load( + Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D), allow_tma=allow_tma + ).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D + lse_i = ct.load(Lse, index=(batch_idx, head_idx, bid_x), shape=(1, 1, TILE_M), + allow_tma=allow_tma).reshape((TILE_M, 1)) + delta_i = ct.load(Delta, index=(batch_idx, head_idx, bid_x), shape=(1, 1, TILE_M), + allow_tma=allow_tma).reshape((TILE_M, 1)) + do = ct.load( + Grad, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D), allow_tma=allow_tma + ).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] + dq = ct.full((TILE_M, TILE_D), 0., dtype=np.float32) # [TILE_M, TILE_D] + + # loop over k, v and update accumulator + # m_end = input_pos + (bid_x + 1) * TILE_M + k_seqlen = K.shape[2] + # if CAUSAL: + # # when kv pos could exceed q pos + # mask_start = (input_pos + bid_x * TILE_M) // TILE_N + # # when kv pos could exceed k_seqlen + # mask_start = min(mask_start, k_seqlen // TILE_N) + # Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N) + # else: + # Tc = ct.cdiv(k_seqlen, TILE_N) + # mask_start = k_seqlen // TILE_N + Tc = ct.cdiv(k_seqlen, TILE_N) + for j in range(0, Tc): + # --- Compute QK product --- + k = ct.load( + K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N), + order=(0, 1, 3, 2), + latency=2, + allow_tma=allow_tma + ) + k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N] + qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) + qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] + + # # --- Apply Causal Masking --- + # if (CAUSAL or not EVEN_K) and j >= mask_start: + # offs_n = j * TILE_N + offs_n_tile + # mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool_) + # # out of bound mask + # if not EVEN_K: + # mask = mask & (offs_n < k_seqlen) + # # causal mask + # if CAUSAL: + # mask = mask & (offs_m >= offs_n) # [TILE_M, TILE_N] + # mask = ct.where(mask, 0.0, -np.inf) # [TILE_M, TILE_N] + # qk += mask + + v = ct.load( + V, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N), + order=(0, 1, 3, 2), + latency=4, + allow_tma=allow_tma + ).reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N] + + dp = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) # [TILE_M, TILE_N] + dp = ct.mma(do, v, dp) # [TILE_M, TILE_N] + dp = dp - delta_i + + qk = qk * qk_scale * INV_LOG_2 + p = ct.exp2(qk - lse_i, flush_to_zero=True) # [TILE_M, TILE_N] + ds = p * dp # [TILE_M, TILE_N] + ds = ds.astype(k.dtype) + + dskt = ct.full((TILE_M, TILE_D), 0., dtype=np.float32) # [TILE_M, TILE_D] + kt = ct.permute(k, (1, 0)) # [TILE_N, TILE_D] + dskt = ct.mma(ds, kt, dskt) # [TILE_M, TILE_D] + dq = dq + dskt # [TILE_M, TILE_D] + + dq = dq * qk_scale + dq = dq.astype(q.dtype).reshape((1, 1, TILE_M, TILE_D)) + ct.store(DQ, index=(batch_idx, head_idx, bid_x, 0), tile=dq) + + + +@ct.kernel(occupancy=2) +def fmha_bwd_dk_dv_kernel(Q, K, V, Grad, Delta, Lse, DK, DV, qk_scale: float, input_pos: int, TILE_D: ConstInt, + H: ConstInt, TILE_M: ConstInt, TILE_N: ConstInt, QUERY_GROUP_SIZE: ConstInt, + CAUSAL: ConstBool, EVEN_K: ConstBool): + bid_x = ct.bid(0) + bid_y = ct.bid(1) + batch_idx = bid_y // H + head_idx = bid_y % H + off_kv_h = head_idx // QUERY_GROUP_SIZE + + # Adjust qk_scale for exp2 + k = ct.load( + K, index=(batch_idx, off_kv_h, 0, bid_x), shape=(1, 1, TILE_D, TILE_N), + order=(0, 1, 3, 2), + latency=2, + allow_tma=allow_tma + ).reshape((TILE_D, TILE_N)) + v = ct.load( + V, index=(batch_idx, off_kv_h, 0, bid_x), shape=(1, 1, TILE_D, TILE_N), + order=(0, 1, 3, 2), + latency=2, + allow_tma=allow_tma + ).reshape((TILE_D, TILE_N)) + + dk = ct.full((TILE_N, TILE_D), 0., dtype=np.float32) # [TILE_N, TILE_D] + dv = ct.full((TILE_N, TILE_D), 0., dtype=np.float32) # [TILE_N, TILE_D] + + # Initialize local offsets for query/key/value tile (M or N-dimension) + offs_m_tile = ct.arange(TILE_M, dtype=np.int32) + input_pos # [TILE_M] + offs_m_tile = offs_m_tile[:, None] # [TILE_M, 1] + offs_n = bid_x * TILE_N + ct.arange(TILE_N, dtype=np.int32) + offs_n = offs_n[None, :] + k_seqlen = K.shape[2] + Tr = ct.cdiv(Q.shape[2], TILE_M) + + # if CAUSAL: + # m_start = bid_x * TILE_N // TILE_M + # else: + # m_start = 0 + for i in range(0, Tr): + q = ct.load(Q, index=(batch_idx, head_idx, i, 0), shape=(1, 1, TILE_M, TILE_D), + latency=2, allow_tma=allow_tma).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] + qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) + qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] + + # loop over k, v and update accumulator + if CAUSAL: + # when kv pos could exceed q pos + mask_start = (input_pos + i * TILE_M) // TILE_N + # when kv pos could exceed k_seqlen + mask_start = min(mask_start, k_seqlen // TILE_N) + else: + mask_start = k_seqlen // TILE_N + + # # --- Apply Causal Masking --- + # if (CAUSAL or not EVEN_K) and bid_x <= mask_start: + # offs_m = i * TILE_M + offs_m_tile + # mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool_) + # # out of bound mask + # if not EVEN_K: + # mask = mask & (offs_n < k_seqlen) + # # causal mask + # if CAUSAL: + # mask = mask & (offs_m >= offs_n) # [TILE_M, TILE_N] + # mask = ct.where(mask, 0.0, -np.inf) # [TILE_M, TILE_N] + # qk += mask + + lse_i = ct.load(Lse, index=(batch_idx, head_idx, i), shape=(1, 1, TILE_M), + allow_tma=allow_tma).reshape((TILE_M, 1)) + qk = qk * qk_scale * INV_LOG_2 + p = ct.exp2(qk - lse_i, flush_to_zero=True) # [TILE_M, TILE_N] + pt = ct.permute(p, (1, 0)) # [TILE_N, TILE_M] + + do = ct.load(Grad, index=(batch_idx, head_idx, i, 0), shape=(1, 1, TILE_M, TILE_D), + latency=4, allow_tma=allow_tma).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] + pt = pt.astype(do.dtype) + ptdo = ct.full((TILE_N, TILE_D), 0., dtype=np.float32) # [TILE_N, TILE_D] + ptdo = ct.mma(pt, do, ptdo) # [TILE_N, TILE_D] + dv = dv + ptdo # [TILE_N, TILE_D] + + dp = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) # [TILE_M, TILE_N] + dp = ct.mma(do, v, dp) # [TILE_M, TILE_N] + delta_i = ct.load(Delta, index=(batch_idx, head_idx, i), shape=(1, 1, TILE_M), + allow_tma=allow_tma).reshape((TILE_M, 1)) + dp = dp - delta_i + ds = p * dp # [TILE_M, TILE_N] + dst = ct.permute(ds, (1, 0)) # [TILE_N, TILE_M] + dst = dst.astype(q.dtype) + dstq = ct.full((TILE_N, TILE_D), 0., dtype=np.float32) # [TILE_N, TILE_D] + dstq = ct.mma(dst, q, dstq) # [TILE_N, TILE_D] + dk = dk + dstq + + dk = dk * qk_scale + dk = dk.astype(k.dtype).reshape((1, 1, TILE_N, TILE_D)) + dv = dv.astype(v.dtype).reshape((1, 1, TILE_N, TILE_D)) + ct.store(DK, index=(batch_idx, head_idx, bid_x, 0), tile=dk) + ct.store(DV, index=(batch_idx, head_idx, bid_x, 0), tile=dv) + + +@ct.kernel(occupancy=1) +def fmha_bwd_preprocess_kernel(O, Grad, Delta, + H: ConstInt, + TILE_M: ConstInt, + TILE_D: ConstInt): + bid_x = ct.bid(0) + bid_y = ct.bid(1) + batch_idx = bid_y // H + head_idx = bid_y % H + o = ct.load(O, index=(batch_idx, head_idx, bid_x, 0), + shape=(1, 1, TILE_M, TILE_D), + latency=2, allow_tma=allow_tma + ).reshape((TILE_M, TILE_D)) + do = ct.load(Grad, index=(batch_idx, head_idx, bid_x, 0), + shape=(1, 1, TILE_M, TILE_D), + latency=2, allow_tma=allow_tma + ).reshape((TILE_M, TILE_D)) + delta = ct.sum(o * do, axis=1).reshape((1, 1, TILE_M)) + ct.store(Delta, index=(batch_idx, head_idx, bid_x), tile=delta) \ No newline at end of file From a47e7c7817fdd16d68c61347467b27bc33a910cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E8=B6=85?= Date: Tue, 30 Dec 2025 14:35:18 +0800 Subject: [PATCH 2/4] flashattention-2 bwd support causal mask --- test/bench_attention.py | 33 ++++---- test/kernels/attention.py | 154 ++++++++++++++++++++------------------ 2 files changed, 99 insertions(+), 88 deletions(-) diff --git a/test/bench_attention.py b/test/bench_attention.py index 31652e2..72894cc 100644 --- a/test/bench_attention.py +++ b/test/bench_attention.py @@ -13,7 +13,7 @@ from util import estimate_bench_iter from kernels.attention import fmha_kernel, fmha_bwd_dq_kernel, fmha_bwd_dk_dv_kernel, fmha_bwd_preprocess_kernel -has_backward = True +has_backward = True def qkv_id(qkv_shape: tuple[tuple[int, ...], tuple[int, ...]]) -> str: q_shape, kv_shape = qkv_shape @@ -29,15 +29,20 @@ def qkv_id(qkv_shape: tuple[tuple[int, ...], tuple[int, ...]]) -> str: @pytest.fixture( params=[ # B, H, L, D - ((1, 32, 1024, 128), (1, 32, 1024, 128)), # prefill + ((6, 24, 67000, 64), (6, 24, 67000, 64)), # tanlan dim 64 + ((6, 24, 67000, 128), (6, 24, 67000, 128)), # tanlan dim 128 + # ((6, 32, 1024, 128), (6, 32, 1024, 128)), # prefill + # ((1, 32, 1024, 64), (1, 32, 1024, 64)), + # ((1, 32, 1024, 64), (1, 8, 1024, 64)), # prefill + gqa # ((1, 32, 1024, 128), (1, 8, 1024, 128)), # prefill + gqa - ((1, 32, 8192, 64), (1, 32, 8192, 64)), - ((1, 32, 8192, 128), (1, 32, 8192, 128)), -# ((1, 32, 8192, 128), (1, 8, 8192, 128)), # prefill + gqa -# ((1, 32, 1, 128), (1, 32, 1024, 128)), # decode -# ((8, 32, 1, 128), (8, 32, 1024, 128)), # decode -# ((1, 32, 1, 128), (1, 8, 1024, 128)), # decode + gqa -# ((8, 32, 1, 128), (8, 8, 1024, 128)), # decode + gqa + # ((1, 32, 8192, 64), (1, 32, 8192, 64)), + # ((1, 32, 8192, 128), (1, 32, 8192, 128)), + # ((1, 32, 8192, 128), (1, 8, 8192, 128)), # prefill + gqa + # ((1, 32, 1, 128), (1, 32, 1024, 128)), # decode + # ((1, 32, 1, 64), (1, 32, 1024, 64)), # decode + # ((8, 32, 1, 128), (8, 32, 1024, 128)), # decode + # ((1, 32, 1, 128), (1, 8, 1024, 128)), # decode + gqa + # ((8, 32, 1, 128), (8, 8, 1024, 128)), # decode + gqa ], ids=qkv_id) def qkv_shape(request): @@ -61,21 +66,19 @@ def bench_fmha(qkv_shape, dtype, backend, benchmark): o = torch.zeros_like(q) ref = torch.zeros_like(q) is_causal = q_shape[2] == kv_shape[2] - is_causal = False enable_gqa = q_shape[1] != kv_shape[1] if has_backward: dq, dk, dv = backend(q, k, v, o, grad, lse, is_causal, enable_gqa) dq_ref, dk_ref, dv_ref = ref_fmha(q, k, v, ref, grad, is_causal, enable_gqa) torch.testing.assert_close(o, ref, atol=1e-2, rtol=5e-2) - torch.testing.assert_close(dq, dq_ref, atol=1e-2, rtol=5e-2) - torch.testing.assert_close(dk, dk_ref, atol=1e-2, rtol=5e-2) - torch.testing.assert_close(dv, dv_ref, atol=1e-2, rtol=5e-2) + torch.testing.assert_close(dq, dq_ref, atol=6e-2, rtol=5e-2) + # torch.testing.assert_close(dk, dk_ref, atol=6e-2, rtol=5e-2) + # torch.testing.assert_close(dv, dv_ref, atol=6e-2, rtol=5e-2) else: backend(q, k, v, o, grad, lse, is_causal, enable_gqa) ref_fmha(q, k, v, ref, grad, is_causal, enable_gqa) - - torch.testing.assert_close(o, ref, atol=1e-2, rtol=5e-2) + torch.testing.assert_close(o, ref, atol=1e-2, rtol=5e-2) torch.cuda.synchronize() warmup_rounds, iterations, rounds = estimate_bench_iter( diff --git a/test/kernels/attention.py b/test/kernels/attention.py index f76b2d4..c2a3e92 100644 --- a/test/kernels/attention.py +++ b/test/kernels/attention.py @@ -8,6 +8,7 @@ from cuda.tile import RoundingMode as RMd from cuda.tile import kernel, ByTarget +from cuda.tile._numeric_semantics import PaddingMode INV_LOG_2 = 1.0 / math.log(2) @@ -149,72 +150,99 @@ def fmha_bwd_dq_kernel(Q, K, V, Grad, Delta, Lse, DQ, qk_scale: float, input_pos # Load query tile for this batch, head, and M-chunk q = ct.load( - Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D), allow_tma=allow_tma + Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D), allow_tma=allow_tma, padding_mode=PaddingMode.ZERO ).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D lse_i = ct.load(Lse, index=(batch_idx, head_idx, bid_x), shape=(1, 1, TILE_M), allow_tma=allow_tma).reshape((TILE_M, 1)) delta_i = ct.load(Delta, index=(batch_idx, head_idx, bid_x), shape=(1, 1, TILE_M), allow_tma=allow_tma).reshape((TILE_M, 1)) do = ct.load( - Grad, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D), allow_tma=allow_tma + Grad, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D), allow_tma=allow_tma, padding_mode=PaddingMode.ZERO ).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] - dq = ct.full((TILE_M, TILE_D), 0., dtype=np.float32) # [TILE_M, TILE_D] - + dq = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32) # [TILE_M, TILE_D] # loop over k, v and update accumulator - # m_end = input_pos + (bid_x + 1) * TILE_M k_seqlen = K.shape[2] - # if CAUSAL: - # # when kv pos could exceed q pos - # mask_start = (input_pos + bid_x * TILE_M) // TILE_N - # # when kv pos could exceed k_seqlen - # mask_start = min(mask_start, k_seqlen // TILE_N) - # Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N) - # else: - # Tc = ct.cdiv(k_seqlen, TILE_N) - # mask_start = k_seqlen // TILE_N - Tc = ct.cdiv(k_seqlen, TILE_N) - for j in range(0, Tc): - # --- Compute QK product --- + if CAUSAL: + m_end = input_pos + (bid_x + 1) * TILE_M + mask_start = (input_pos + bid_x * TILE_M) // TILE_N + mask_start = min(mask_start, k_seqlen // TILE_N) + Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N) + else: + Tc = ct.cdiv(k_seqlen, TILE_N) + mask_start = Tc + if not EVEN_K: + mask_start = Tc - 1 + for j in range(0, mask_start): k = ct.load( K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N), order=(0, 1, 3, 2), latency=2, - allow_tma=allow_tma - ) - k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N] - qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) + allow_tma=allow_tma, padding_mode=PaddingMode.ZERO + ).reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N] + qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32) qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] - - # # --- Apply Causal Masking --- - # if (CAUSAL or not EVEN_K) and j >= mask_start: - # offs_n = j * TILE_N + offs_n_tile - # mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool_) - # # out of bound mask - # if not EVEN_K: - # mask = mask & (offs_n < k_seqlen) - # # causal mask - # if CAUSAL: - # mask = mask & (offs_m >= offs_n) # [TILE_M, TILE_N] - # mask = ct.where(mask, 0.0, -np.inf) # [TILE_M, TILE_N] - # qk += mask + qk = qk * qk_scale * INV_LOG_2 + p = ct.exp2(qk - lse_i, flush_to_zero=True) # [TILE_M, TILE_N] v = ct.load( V, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N), order=(0, 1, 3, 2), latency=4, - allow_tma=allow_tma + allow_tma=allow_tma, padding_mode=PaddingMode.ZERO ).reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N] - dp = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) # [TILE_M, TILE_N] + dp = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32) # [TILE_M, TILE_N] dp = ct.mma(do, v, dp) # [TILE_M, TILE_N] dp = dp - delta_i + ds = p * dp # [TILE_M, TILE_N] + ds = ds.astype(k.dtype) + + dskt = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32) # [TILE_M, TILE_D] + kt = ct.permute(k, (1, 0)) # [TILE_N, TILE_D] + dskt = ct.mma(ds, kt, dskt) # [TILE_M, TILE_D] + dq = dq + dskt # [TILE_M, TILE_D] + + for j in range(mask_start, Tc): + # --- Compute QK product --- + k = ct.load( + K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N), + order=(0, 1, 3, 2), + latency=2, + allow_tma=allow_tma, padding_mode=PaddingMode.ZERO + ) + k = k.reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N] + qk = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32) + qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] + + # --- Apply Causal Masking --- + offs_n = j * TILE_N + offs_n_tile + mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool_) + # out of bound mask + if not EVEN_K: + mask = mask & (offs_n < k_seqlen) + # causal mask + if CAUSAL: + mask = mask & (offs_m >= offs_n) # [TILE_M, TILE_N] + mask = ct.where(mask, 0.0, -np.inf) # [TILE_M, TILE_N] + qk += mask qk = qk * qk_scale * INV_LOG_2 p = ct.exp2(qk - lse_i, flush_to_zero=True) # [TILE_M, TILE_N] + + v = ct.load( + V, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N), + order=(0, 1, 3, 2), + latency=4, + allow_tma=allow_tma, padding_mode=PaddingMode.ZERO + ).reshape((TILE_D, TILE_N)) # [TILE_D, TILE_N] + + dp = ct.full((TILE_M, TILE_N), 0.0, dtype=np.float32) # [TILE_M, TILE_N] + dp = ct.mma(do, v, dp) # [TILE_M, TILE_N] + dp = dp - delta_i ds = p * dp # [TILE_M, TILE_N] ds = ds.astype(k.dtype) - dskt = ct.full((TILE_M, TILE_D), 0., dtype=np.float32) # [TILE_M, TILE_D] + dskt = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32) # [TILE_M, TILE_D] kt = ct.permute(k, (1, 0)) # [TILE_N, TILE_D] dskt = ct.mma(ds, kt, dskt) # [TILE_M, TILE_D] dq = dq + dskt # [TILE_M, TILE_D] @@ -249,48 +277,32 @@ def fmha_bwd_dk_dv_kernel(Q, K, V, Grad, Delta, Lse, DK, DV, qk_scale: float, in allow_tma=allow_tma ).reshape((TILE_D, TILE_N)) - dk = ct.full((TILE_N, TILE_D), 0., dtype=np.float32) # [TILE_N, TILE_D] - dv = ct.full((TILE_N, TILE_D), 0., dtype=np.float32) # [TILE_N, TILE_D] + dk = ct.full((TILE_N, TILE_D), 0.0, dtype=np.float32) # [TILE_N, TILE_D] + dv = ct.full((TILE_N, TILE_D), 0.0, dtype=np.float32) # [TILE_N, TILE_D] # Initialize local offsets for query/key/value tile (M or N-dimension) offs_m_tile = ct.arange(TILE_M, dtype=np.int32) + input_pos # [TILE_M] offs_m_tile = offs_m_tile[:, None] # [TILE_M, 1] offs_n = bid_x * TILE_N + ct.arange(TILE_N, dtype=np.int32) offs_n = offs_n[None, :] - k_seqlen = K.shape[2] Tr = ct.cdiv(Q.shape[2], TILE_M) - # if CAUSAL: - # m_start = bid_x * TILE_N // TILE_M - # else: - # m_start = 0 - for i in range(0, Tr): + if CAUSAL: + m_start = bid_x * TILE_N // TILE_M + mask_end = ct.cdiv((bid_x + 1) * TILE_N, TILE_M) + else: + m_start = 0 + for i in range(m_start, Tr): q = ct.load(Q, index=(batch_idx, head_idx, i, 0), shape=(1, 1, TILE_M, TILE_D), latency=2, allow_tma=allow_tma).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] - - # loop over k, v and update accumulator - if CAUSAL: - # when kv pos could exceed q pos - mask_start = (input_pos + i * TILE_M) // TILE_N - # when kv pos could exceed k_seqlen - mask_start = min(mask_start, k_seqlen // TILE_N) - else: - mask_start = k_seqlen // TILE_N - - # # --- Apply Causal Masking --- - # if (CAUSAL or not EVEN_K) and bid_x <= mask_start: - # offs_m = i * TILE_M + offs_m_tile - # mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool_) - # # out of bound mask - # if not EVEN_K: - # mask = mask & (offs_n < k_seqlen) - # # causal mask - # if CAUSAL: - # mask = mask & (offs_m >= offs_n) # [TILE_M, TILE_N] - # mask = ct.where(mask, 0.0, -np.inf) # [TILE_M, TILE_N] - # qk += mask + if (CAUSAL or not EVEN_K) and i <= mask_end: + mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool_) + offs_m = i * TILE_M + offs_m_tile + mask = mask & (offs_m >= offs_n) + mask = ct.where(mask, 0.0, -np.inf) + qk += mask lse_i = ct.load(Lse, index=(batch_idx, head_idx, i), shape=(1, 1, TILE_M), allow_tma=allow_tma).reshape((TILE_M, 1)) @@ -301,9 +313,7 @@ def fmha_bwd_dk_dv_kernel(Q, K, V, Grad, Delta, Lse, DK, DV, qk_scale: float, in do = ct.load(Grad, index=(batch_idx, head_idx, i, 0), shape=(1, 1, TILE_M, TILE_D), latency=4, allow_tma=allow_tma).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] pt = pt.astype(do.dtype) - ptdo = ct.full((TILE_N, TILE_D), 0., dtype=np.float32) # [TILE_N, TILE_D] - ptdo = ct.mma(pt, do, ptdo) # [TILE_N, TILE_D] - dv = dv + ptdo # [TILE_N, TILE_D] + dv = ct.mma(pt, do, dv) # [TILE_N, TILE_D] dp = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) # [TILE_M, TILE_N] dp = ct.mma(do, v, dp) # [TILE_M, TILE_N] @@ -313,9 +323,7 @@ def fmha_bwd_dk_dv_kernel(Q, K, V, Grad, Delta, Lse, DK, DV, qk_scale: float, in ds = p * dp # [TILE_M, TILE_N] dst = ct.permute(ds, (1, 0)) # [TILE_N, TILE_M] dst = dst.astype(q.dtype) - dstq = ct.full((TILE_N, TILE_D), 0., dtype=np.float32) # [TILE_N, TILE_D] - dstq = ct.mma(dst, q, dstq) # [TILE_N, TILE_D] - dk = dk + dstq + dk = ct.mma(dst, q, dk) # [TILE_N, TILE_D] dk = dk * qk_scale dk = dk.astype(k.dtype).reshape((1, 1, TILE_N, TILE_D)) From 4f3ef4b0214bbd9eaad77261cf0f681c5d338e60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E8=B6=85?= Date: Tue, 30 Dec 2025 20:28:14 +0800 Subject: [PATCH 3/4] support dq dk dv computation in accuracy atol=1e-2, rtol=5e-2 --- test/bench_attention.py | 20 ++++++++++---------- test/kernels/attention.py | 16 ++++++++-------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/test/bench_attention.py b/test/bench_attention.py index 72894cc..2fa810c 100644 --- a/test/bench_attention.py +++ b/test/bench_attention.py @@ -31,16 +31,16 @@ def qkv_id(qkv_shape: tuple[tuple[int, ...], tuple[int, ...]]) -> str: # B, H, L, D ((6, 24, 67000, 64), (6, 24, 67000, 64)), # tanlan dim 64 ((6, 24, 67000, 128), (6, 24, 67000, 128)), # tanlan dim 128 - # ((6, 32, 1024, 128), (6, 32, 1024, 128)), # prefill - # ((1, 32, 1024, 64), (1, 32, 1024, 64)), + ((6, 32, 1024, 128), (6, 32, 1024, 128)), # prefill + ((1, 32, 1024, 64), (1, 32, 1024, 64)), # ((1, 32, 1024, 64), (1, 8, 1024, 64)), # prefill + gqa # ((1, 32, 1024, 128), (1, 8, 1024, 128)), # prefill + gqa - # ((1, 32, 8192, 64), (1, 32, 8192, 64)), - # ((1, 32, 8192, 128), (1, 32, 8192, 128)), + ((1, 32, 8192, 64), (1, 32, 8192, 64)), + ((1, 32, 8192, 128), (1, 32, 8192, 128)), # ((1, 32, 8192, 128), (1, 8, 8192, 128)), # prefill + gqa - # ((1, 32, 1, 128), (1, 32, 1024, 128)), # decode - # ((1, 32, 1, 64), (1, 32, 1024, 64)), # decode - # ((8, 32, 1, 128), (8, 32, 1024, 128)), # decode + ((1, 32, 1, 128), (1, 32, 1024, 128)), # decode + ((1, 32, 1, 64), (1, 32, 1024, 64)), # decode + ((8, 32, 1, 128), (8, 32, 1024, 128)), # decode # ((1, 32, 1, 128), (1, 8, 1024, 128)), # decode + gqa # ((8, 32, 1, 128), (8, 8, 1024, 128)), # decode + gqa ], @@ -72,9 +72,9 @@ def bench_fmha(qkv_shape, dtype, backend, benchmark): dq, dk, dv = backend(q, k, v, o, grad, lse, is_causal, enable_gqa) dq_ref, dk_ref, dv_ref = ref_fmha(q, k, v, ref, grad, is_causal, enable_gqa) torch.testing.assert_close(o, ref, atol=1e-2, rtol=5e-2) - torch.testing.assert_close(dq, dq_ref, atol=6e-2, rtol=5e-2) - # torch.testing.assert_close(dk, dk_ref, atol=6e-2, rtol=5e-2) - # torch.testing.assert_close(dv, dv_ref, atol=6e-2, rtol=5e-2) + torch.testing.assert_close(dq, dq_ref, atol=1e-2, rtol=5e-2) + torch.testing.assert_close(dk, dk_ref, atol=1e-2, rtol=5e-2) + torch.testing.assert_close(dv, dv_ref, atol=1e-2, rtol=5e-2) else: backend(q, k, v, o, grad, lse, is_causal, enable_gqa) ref_fmha(q, k, v, ref, grad, is_causal, enable_gqa) diff --git a/test/kernels/attention.py b/test/kernels/attention.py index c2a3e92..84f9ac4 100644 --- a/test/kernels/attention.py +++ b/test/kernels/attention.py @@ -7,7 +7,6 @@ import math from cuda.tile import RoundingMode as RMd -from cuda.tile import kernel, ByTarget from cuda.tile._numeric_semantics import PaddingMode INV_LOG_2 = 1.0 / math.log(2) @@ -268,13 +267,13 @@ def fmha_bwd_dk_dv_kernel(Q, K, V, Grad, Delta, Lse, DK, DV, qk_scale: float, in K, index=(batch_idx, off_kv_h, 0, bid_x), shape=(1, 1, TILE_D, TILE_N), order=(0, 1, 3, 2), latency=2, - allow_tma=allow_tma + allow_tma=allow_tma, padding_mode=PaddingMode.ZERO ).reshape((TILE_D, TILE_N)) v = ct.load( V, index=(batch_idx, off_kv_h, 0, bid_x), shape=(1, 1, TILE_D, TILE_N), order=(0, 1, 3, 2), latency=2, - allow_tma=allow_tma + allow_tma=allow_tma, padding_mode=PaddingMode.ZERO ).reshape((TILE_D, TILE_N)) dk = ct.full((TILE_N, TILE_D), 0.0, dtype=np.float32) # [TILE_N, TILE_D] @@ -294,7 +293,7 @@ def fmha_bwd_dk_dv_kernel(Q, K, V, Grad, Delta, Lse, DK, DV, qk_scale: float, in m_start = 0 for i in range(m_start, Tr): q = ct.load(Q, index=(batch_idx, head_idx, i, 0), shape=(1, 1, TILE_M, TILE_D), - latency=2, allow_tma=allow_tma).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] + latency=2, allow_tma=allow_tma, padding_mode=PaddingMode.ZERO).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] if (CAUSAL or not EVEN_K) and i <= mask_end: @@ -305,20 +304,20 @@ def fmha_bwd_dk_dv_kernel(Q, K, V, Grad, Delta, Lse, DK, DV, qk_scale: float, in qk += mask lse_i = ct.load(Lse, index=(batch_idx, head_idx, i), shape=(1, 1, TILE_M), - allow_tma=allow_tma).reshape((TILE_M, 1)) + allow_tma=allow_tma, padding_mode=PaddingMode.ZERO).reshape((TILE_M, 1)) qk = qk * qk_scale * INV_LOG_2 p = ct.exp2(qk - lse_i, flush_to_zero=True) # [TILE_M, TILE_N] pt = ct.permute(p, (1, 0)) # [TILE_N, TILE_M] do = ct.load(Grad, index=(batch_idx, head_idx, i, 0), shape=(1, 1, TILE_M, TILE_D), - latency=4, allow_tma=allow_tma).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] + latency=4, allow_tma=allow_tma, padding_mode=PaddingMode.ZERO).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] pt = pt.astype(do.dtype) dv = ct.mma(pt, do, dv) # [TILE_N, TILE_D] dp = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) # [TILE_M, TILE_N] dp = ct.mma(do, v, dp) # [TILE_M, TILE_N] delta_i = ct.load(Delta, index=(batch_idx, head_idx, i), shape=(1, 1, TILE_M), - allow_tma=allow_tma).reshape((TILE_M, 1)) + allow_tma=allow_tma, padding_mode=PaddingMode.ZERO).reshape((TILE_M, 1)) dp = dp - delta_i ds = p * dp # [TILE_M, TILE_N] dst = ct.permute(ds, (1, 0)) # [TILE_N, TILE_M] @@ -349,5 +348,6 @@ def fmha_bwd_preprocess_kernel(O, Grad, Delta, shape=(1, 1, TILE_M, TILE_D), latency=2, allow_tma=allow_tma ).reshape((TILE_M, TILE_D)) - delta = ct.sum(o * do, axis=1).reshape((1, 1, TILE_M)) + delta = ct.mul(o.astype(ct.float32), do.astype(ct.float32), flush_to_zero=True) + delta = ct.sum(delta, axis=1).reshape((1, 1, TILE_M)) ct.store(Delta, index=(batch_idx, head_idx, bid_x), tile=delta) \ No newline at end of file From 4858282ec54ebe48befb92ae26f0c6903eece738 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E8=B6=85?= Date: Mon, 5 Jan 2026 11:09:38 +0800 Subject: [PATCH 4/4] support gqa --- test/bench_attention.py | 18 +++++----- test/kernels/attention.py | 74 ++++++++++++++++++--------------------- 2 files changed, 44 insertions(+), 48 deletions(-) diff --git a/test/bench_attention.py b/test/bench_attention.py index 2fa810c..6303ab7 100644 --- a/test/bench_attention.py +++ b/test/bench_attention.py @@ -33,16 +33,16 @@ def qkv_id(qkv_shape: tuple[tuple[int, ...], tuple[int, ...]]) -> str: ((6, 24, 67000, 128), (6, 24, 67000, 128)), # tanlan dim 128 ((6, 32, 1024, 128), (6, 32, 1024, 128)), # prefill ((1, 32, 1024, 64), (1, 32, 1024, 64)), - # ((1, 32, 1024, 64), (1, 8, 1024, 64)), # prefill + gqa - # ((1, 32, 1024, 128), (1, 8, 1024, 128)), # prefill + gqa + ((1, 32, 1024, 64), (1, 8, 1024, 64)), # prefill + gqa + ((1, 32, 1024, 128), (1, 8, 1024, 128)), # prefill + gqa ((1, 32, 8192, 64), (1, 32, 8192, 64)), ((1, 32, 8192, 128), (1, 32, 8192, 128)), - # ((1, 32, 8192, 128), (1, 8, 8192, 128)), # prefill + gqa + ((1, 32, 8192, 128), (1, 8, 8192, 128)), # prefill + gqa ((1, 32, 1, 128), (1, 32, 1024, 128)), # decode ((1, 32, 1, 64), (1, 32, 1024, 64)), # decode ((8, 32, 1, 128), (8, 32, 1024, 128)), # decode - # ((1, 32, 1, 128), (1, 8, 1024, 128)), # decode + gqa - # ((8, 32, 1, 128), (8, 8, 1024, 128)), # decode + gqa + ((1, 32, 1, 128), (1, 8, 1024, 128)), # decode + gqa + ((8, 32, 1, 128), (8, 8, 1024, 128)), # decode + gqa ], ids=qkv_id) def qkv_shape(request): @@ -109,10 +109,10 @@ def cutile_fmha(q, k, v, o, grad, lse, is_causal, enable_gqa): qk_scale = 1 / sqrt(d) TILE_M, TILE_N = (256, 128) if is_causal else (64, 128) query_group_size = qh // kh - grid = (ceil(q_len / TILE_M), b * qh, 1) input_pos = 0 if q_len == k_len else (k_len - 1) EVEN_K = (k_len % TILE_N) == 0 - ct.launch(torch.cuda.current_stream(), grid, fmha_kernel, + ct.launch(torch.cuda.current_stream(), (ceil(q_len / TILE_M), b * qh, 1), + fmha_kernel, (q, k, v, o, lse, qk_scale, input_pos, d, qh, TILE_M, TILE_N, query_group_size, is_causal, EVEN_K)) @@ -124,10 +124,10 @@ def cutile_fmha(q, k, v, o, grad, lse, is_causal, enable_gqa): ct.launch(torch.cuda.current_stream(), (ceil(q_len / TILE_M), b * qh, 1), fmha_bwd_preprocess_kernel, (o, grad, delta, qh, TILE_M, d)) - ct.launch(torch.cuda.current_stream(), (ceil(k_len / TILE_N), b * qh, 1), + ct.launch(torch.cuda.current_stream(), (ceil(k_len / TILE_N), b * kh, 1), fmha_bwd_dk_dv_kernel, (q, k, v, grad, delta, lse, dk, dv, - qk_scale, input_pos, d, qh, TILE_M, TILE_N, + qk_scale, input_pos, d, kh, TILE_M, TILE_N, query_group_size, is_causal, EVEN_K)) ct.launch(torch.cuda.current_stream(), (ceil(q_len / TILE_M), b * qh, 1), fmha_bwd_dq_kernel, diff --git a/test/kernels/attention.py b/test/kernels/attention.py index 84f9ac4..cdc24cb 100644 --- a/test/kernels/attention.py +++ b/test/kernels/attention.py @@ -250,7 +250,6 @@ def fmha_bwd_dq_kernel(Q, K, V, Grad, Delta, Lse, DQ, qk_scale: float, input_pos dq = dq.astype(q.dtype).reshape((1, 1, TILE_M, TILE_D)) ct.store(DQ, index=(batch_idx, head_idx, bid_x, 0), tile=dq) - @ct.kernel(occupancy=2) def fmha_bwd_dk_dv_kernel(Q, K, V, Grad, Delta, Lse, DK, DV, qk_scale: float, input_pos: int, TILE_D: ConstInt, @@ -260,19 +259,15 @@ def fmha_bwd_dk_dv_kernel(Q, K, V, Grad, Delta, Lse, DK, DV, qk_scale: float, in bid_y = ct.bid(1) batch_idx = bid_y // H head_idx = bid_y % H - off_kv_h = head_idx // QUERY_GROUP_SIZE - # Adjust qk_scale for exp2 k = ct.load( - K, index=(batch_idx, off_kv_h, 0, bid_x), shape=(1, 1, TILE_D, TILE_N), + K, index=(batch_idx, head_idx, 0, bid_x), shape=(1, 1, TILE_D, TILE_N), order=(0, 1, 3, 2), - latency=2, allow_tma=allow_tma, padding_mode=PaddingMode.ZERO ).reshape((TILE_D, TILE_N)) v = ct.load( - V, index=(batch_idx, off_kv_h, 0, bid_x), shape=(1, 1, TILE_D, TILE_N), + V, index=(batch_idx, head_idx, 0, bid_x), shape=(1, 1, TILE_D, TILE_N), order=(0, 1, 3, 2), - latency=2, allow_tma=allow_tma, padding_mode=PaddingMode.ZERO ).reshape((TILE_D, TILE_N)) @@ -291,38 +286,39 @@ def fmha_bwd_dk_dv_kernel(Q, K, V, Grad, Delta, Lse, DK, DV, qk_scale: float, in mask_end = ct.cdiv((bid_x + 1) * TILE_N, TILE_M) else: m_start = 0 - for i in range(m_start, Tr): - q = ct.load(Q, index=(batch_idx, head_idx, i, 0), shape=(1, 1, TILE_M, TILE_D), - latency=2, allow_tma=allow_tma, padding_mode=PaddingMode.ZERO).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] - qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) - qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] - if (CAUSAL or not EVEN_K) and i <= mask_end: - mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool_) - offs_m = i * TILE_M + offs_m_tile - mask = mask & (offs_m >= offs_n) - mask = ct.where(mask, 0.0, -np.inf) - qk += mask - - lse_i = ct.load(Lse, index=(batch_idx, head_idx, i), shape=(1, 1, TILE_M), - allow_tma=allow_tma, padding_mode=PaddingMode.ZERO).reshape((TILE_M, 1)) - qk = qk * qk_scale * INV_LOG_2 - p = ct.exp2(qk - lse_i, flush_to_zero=True) # [TILE_M, TILE_N] - pt = ct.permute(p, (1, 0)) # [TILE_N, TILE_M] - - do = ct.load(Grad, index=(batch_idx, head_idx, i, 0), shape=(1, 1, TILE_M, TILE_D), - latency=4, allow_tma=allow_tma, padding_mode=PaddingMode.ZERO).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] - pt = pt.astype(do.dtype) - dv = ct.mma(pt, do, dv) # [TILE_N, TILE_D] - - dp = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) # [TILE_M, TILE_N] - dp = ct.mma(do, v, dp) # [TILE_M, TILE_N] - delta_i = ct.load(Delta, index=(batch_idx, head_idx, i), shape=(1, 1, TILE_M), - allow_tma=allow_tma, padding_mode=PaddingMode.ZERO).reshape((TILE_M, 1)) - dp = dp - delta_i - ds = p * dp # [TILE_M, TILE_N] - dst = ct.permute(ds, (1, 0)) # [TILE_N, TILE_M] - dst = dst.astype(q.dtype) - dk = ct.mma(dst, q, dk) # [TILE_N, TILE_D] + for j in range(QUERY_GROUP_SIZE): + for i in range(m_start, Tr): + q = ct.load(Q, index=(batch_idx, head_idx * QUERY_GROUP_SIZE + j, i, 0), shape=(1, 1, TILE_M, TILE_D), + latency=2, allow_tma=allow_tma, padding_mode=PaddingMode.ZERO).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] + qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) + qk = ct.mma(q, k, qk) # [TILE_M, TILE_N] + if (CAUSAL or not EVEN_K) and i <= mask_end: + mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool_) + offs_m = i * TILE_M + offs_m_tile + mask = mask & (offs_m >= offs_n) + mask = ct.where(mask, 0.0, -np.inf) + qk += mask + + lse_i = ct.load(Lse, index=(batch_idx, head_idx * QUERY_GROUP_SIZE + j, i), shape=(1, 1, TILE_M), + allow_tma=allow_tma, padding_mode=PaddingMode.ZERO).reshape((TILE_M, 1)) + qk = qk * qk_scale * INV_LOG_2 + p = ct.exp2(qk - lse_i, flush_to_zero=True) # [TILE_M, TILE_N] + pt = ct.permute(p, (1, 0)) # [TILE_N, TILE_M] + + do = ct.load(Grad, index=(batch_idx, head_idx * QUERY_GROUP_SIZE + j, i, 0), shape=(1, 1, TILE_M, TILE_D), + latency=4, allow_tma=allow_tma, padding_mode=PaddingMode.ZERO).reshape((TILE_M, TILE_D)) # [TILE_M, TILE_D] + pt = pt.astype(do.dtype) + dv = ct.mma(pt, do, dv) # [TILE_N, TILE_D] + + dp = ct.full((TILE_M, TILE_N), 0., dtype=np.float32) # [TILE_M, TILE_N] + dp = ct.mma(do, v, dp) # [TILE_M, TILE_N] + delta_i = ct.load(Delta, index=(batch_idx, head_idx * QUERY_GROUP_SIZE + j, i), shape=(1, 1, TILE_M), + allow_tma=allow_tma, padding_mode=PaddingMode.ZERO).reshape((TILE_M, 1)) + dp = dp - delta_i + ds = p * dp # [TILE_M, TILE_N] + dst = ct.permute(ds, (1, 0)) # [TILE_N, TILE_M] + dst = dst.astype(q.dtype) + dk = ct.mma(dst, q, dk) # [TILE_N, TILE_D] dk = dk * qk_scale dk = dk.astype(k.dtype).reshape((1, 1, TILE_N, TILE_D))