# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
| 15 | +"""Tests for CuTeDSL backends of BMM operations. |
| 16 | +
|
| 17 | +ALL TESTS PASSING ON H100 (SM90)! ✓ |
| 18 | +
|
| 19 | +This test file covers multiple BMM layout combinations with CuTeDSL kernels: |
| 20 | + - bmm_rcr: C[B,M,N] = A[B,M,K] @ B[B,N,K]^T (IDEAL - no transpose needed) ✓ |
| 21 | + - bmm_rcr_add: C[B,M,N] = A[B,M,K] @ B[B,N,K]^T + D[B,M,N] ✓ |
| 22 | + - bmm_rrr: C[B,M,N] = A[B,M,K] @ B[B,K,N] ✓ |
| 23 | + - bmm_ccr: C[B,M,N] = A[B,K,M]^T @ B[B,N,K]^T ✓ |
| 24 | + - bmm_rrr_add: C[B,M,N] = A[B,M,K] @ B[B,K,N] + D[B,M,N] ✓ |
| 25 | + - bmm_ccr_add: C[B,M,N] = A[B,K,M]^T @ B[B,N,K]^T + D[B,M,N] ✓ |
| 26 | +
|
Key fix: the tensor descriptor format was corrected so that:
  - dynamic_shapes[0]  = B_dim (batch size)
  - dynamic_shapes[1]  = first inner dimension extent (M, K, or N)
  - dynamic_strides[0] = batch stride
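
Illustrative descriptor values (hypothetical numbers, for a contiguous
A[B,M,K] input with B,M,K = 2,256,128):
    dynamic_shapes[0]  = 2            # B (batch size)
    dynamic_shapes[1]  = 256          # M (first inner dimension extent)
    dynamic_strides[0] = 256 * 128    # batch stride = M * K elements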

MMA layout requirements:
  - A operand: [M, K] with K contiguous
  - B operand: [N, K] with K contiguous

Layout analysis:
  - RCR: A[M,K] K-contiguous ✓, B[N,K] K-contiguous ✓ (ideal: no transpose)
  - RRR: A[M,K] K-contiguous ✓, B[K,N] needs a logical transpose
  - CCR: A[K,M] needs a logical transpose, B[N,K] K-contiguous ✓
  (the PyTorch equivalent of each layout is sketched in the illustrative
  _reference_bmm helper below)

|
| 41 | +Run with: |
| 42 | + buck run fbcode//aitemplate/AITemplate/examples:test_cutedsl_bmm |
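The file can also be run directly with Python, since it ends in
unittest.main(); the exact invocation outside Buck depends on your
environment.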
| 43 | +""" |
| 44 | + |
import unittest

import torch
from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import Tensor
from aitemplate.testing.detect_target import FBCUDA


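# A minimal pure-PyTorch sketch of the reference semantics each layout is
# tested against below. `_reference_bmm` is illustrative only: it is not part
# of the AITemplate API and is not used by the tests.
def _reference_bmm(layout: str, a, b, d=None):
    """Compute the PyTorch reference result for a BMM layout string."""
    if layout.startswith("ccr"):
        a = a.transpose(-2, -1)  # col-major A[B,K,M] viewed as [B,M,K]
    if layout.startswith("rcr") or layout.startswith("ccr"):
        b = b.transpose(-2, -1)  # col-major B[B,N,K] viewed as [B,K,N]
    y = torch.bmm(a, b)
    return y + d if d is not None else y  # *_add variants add a residual D

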
def _get_target(**kwargs):
    cc_major, cc_minor = torch.cuda.get_device_capability(0)
    gpu_arch = str(cc_major * 10 + cc_minor)
    if int(gpu_arch) < 80:
        raise RuntimeError(f"SM80+ required, got SM{gpu_arch}")
    return FBCUDA(arch=gpu_arch, **kwargs)


class CuTeDSLBmmTest(unittest.TestCase):
    """Tests for basic BMM operations (bmm_rrr, bmm_ccr).

    These tests use layouts that require logical transpose handling: the
    operand is reinterpreted with swapped shape and strides (e.g. row-major
    B[K,N] with strides (N, 1) viewed as B[N,K] with strides (1, N)) without
    moving any data. Fixed by correcting the tensor descriptor format in the
    wrapper.
    """

    def test_bmm_rrr(self):
        """Test bmm_rrr: C[B,M,N] = A[B,M,K] @ B[B,K,N]"""
        B, M, N, K = 2, 256, 512, 128
        dtype = "float16"

        A = Tensor(shape=[B, M, K], dtype=dtype, name="A", is_input=True)
        W = Tensor(shape=[B, K, N], dtype=dtype, name="W", is_input=True)
        Y = ops.bmm_rrr()(A, W)
        Y._attrs["name"] = "Y"
        Y._attrs["is_output"] = True

        target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=True)
        with compile_model(Y, target, "./tmp", "test_cutedsl_bmm_rrr") as module:
            a_pt = torch.randn(B, M, K, device="cuda", dtype=torch.float16)
            w_pt = torch.randn(B, K, N, device="cuda", dtype=torch.float16)

            # PyTorch reference: standard batched matmul
            y_ref = torch.bmm(a_pt, w_pt)

            y_ait = torch.empty(B, M, N, device="cuda", dtype=torch.float16)
            module.run_with_tensors(
                {"A": a_pt, "W": w_pt},
                {"Y": y_ait},
            )

            self.assertTrue(
                torch.allclose(y_ait, y_ref, atol=1e-1, rtol=1e-1),
                f"bmm_rrr: max diff = {(y_ait - y_ref).abs().max().item():.6f}",
            )

    def test_bmm_ccr(self):
        """Test bmm_ccr: C[B,M,N] = A[B,K,M]^T @ B[B,N,K]^T"""
        B, M, N, K = 2, 256, 512, 128
        dtype = "float16"

        # A is col-major [B, K, M], B is col-major [B, N, K]
        A = Tensor(shape=[B, K, M], dtype=dtype, name="A", is_input=True)
        W = Tensor(shape=[B, N, K], dtype=dtype, name="W", is_input=True)
        Y = ops.bmm_ccr()(A, W)
        Y._attrs["name"] = "Y"
        Y._attrs["is_output"] = True

        target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=True)
        with compile_model(Y, target, "./tmp", "test_cutedsl_bmm_ccr") as module:
            a_pt = torch.randn(B, K, M, device="cuda", dtype=torch.float16)
            w_pt = torch.randn(B, N, K, device="cuda", dtype=torch.float16)

            # PyTorch reference:
            # bmm_ccr: A^T @ W^T = transpose(A, -2, -1) @ transpose(W, -2, -1)
            a_t = a_pt.transpose(-2, -1)  # [B, M, K]
            w_t = w_pt.transpose(-2, -1)  # [B, K, N]
            y_ref = torch.bmm(a_t, w_t)

            y_ait = torch.empty(B, M, N, device="cuda", dtype=torch.float16)
            module.run_with_tensors(
                {"A": a_pt, "W": w_pt},
                {"Y": y_ait},
            )

            self.assertTrue(
                torch.allclose(y_ait, y_ref, atol=1e-1, rtol=1e-1),
                f"bmm_ccr: max diff = {(y_ait - y_ref).abs().max().item():.6f}",
            )


class CuTeDSLBmmAddTest(unittest.TestCase):
    """Tests for BMM with residual add (bmm_rrr_add, bmm_ccr_add).

    Fixed by correcting the tensor descriptor format in the wrapper.
    """

    def test_bmm_rrr_add(self):
        """Test bmm_rrr_add: C[B,M,N] = A[B,M,K] @ B[B,K,N] + D[B,M,N]"""
        B, M, N, K = 2, 256, 512, 128
        dtype = "float16"

        A = Tensor(shape=[B, M, K], dtype=dtype, name="A", is_input=True)
        W = Tensor(shape=[B, K, N], dtype=dtype, name="W", is_input=True)
        D = Tensor(shape=[B, M, N], dtype=dtype, name="D", is_input=True)
        Y = ops.bmm_rrr_add()(A, W, D)
        Y._attrs["name"] = "Y"
        Y._attrs["is_output"] = True

        target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=True)
        with compile_model(Y, target, "./tmp", "test_cutedsl_bmm_rrr_add") as module:
            a_pt = torch.randn(B, M, K, device="cuda", dtype=torch.float16)
            w_pt = torch.randn(B, K, N, device="cuda", dtype=torch.float16)
            d_pt = torch.randn(B, M, N, device="cuda", dtype=torch.float16)

            # PyTorch reference: bmm + add
            y_ref = torch.bmm(a_pt, w_pt) + d_pt

            y_ait = torch.empty(B, M, N, device="cuda", dtype=torch.float16)
            module.run_with_tensors(
                {"A": a_pt, "W": w_pt, "D": d_pt},
                {"Y": y_ait},
            )

            self.assertTrue(
                torch.allclose(y_ait, y_ref, atol=1e-1, rtol=1e-1),
                f"bmm_rrr_add: max diff = {(y_ait - y_ref).abs().max().item():.6f}",
            )

    def test_bmm_ccr_add(self):
        """Test bmm_ccr_add: C[B,M,N] = A[B,K,M]^T @ B[B,N,K]^T + D[B,M,N]"""
        B, M, N, K = 2, 256, 512, 128
        dtype = "float16"

        # A is col-major [B, K, M], B is col-major [B, N, K]
        A = Tensor(shape=[B, K, M], dtype=dtype, name="A", is_input=True)
        W = Tensor(shape=[B, N, K], dtype=dtype, name="W", is_input=True)
        D = Tensor(shape=[B, M, N], dtype=dtype, name="D", is_input=True)
        Y = ops.bmm_ccr_add()(A, W, D)
        Y._attrs["name"] = "Y"
        Y._attrs["is_output"] = True

        target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=True)
        with compile_model(Y, target, "./tmp", "test_cutedsl_bmm_ccr_add") as module:
            a_pt = torch.randn(B, K, M, device="cuda", dtype=torch.float16)
            w_pt = torch.randn(B, N, K, device="cuda", dtype=torch.float16)
            d_pt = torch.randn(B, M, N, device="cuda", dtype=torch.float16)

            # PyTorch reference:
            # bmm_ccr_add: A^T @ W^T + D
            a_t = a_pt.transpose(-2, -1)  # [B, M, K]
            w_t = w_pt.transpose(-2, -1)  # [B, K, N]
            y_ref = torch.bmm(a_t, w_t) + d_pt

            y_ait = torch.empty(B, M, N, device="cuda", dtype=torch.float16)
            module.run_with_tensors(
                {"A": a_pt, "W": w_pt, "D": d_pt},
                {"Y": y_ait},
            )

            self.assertTrue(
                torch.allclose(y_ait, y_ref, atol=1e-1, rtol=1e-1),
                f"bmm_ccr_add: max diff = {(y_ait - y_ref).abs().max().item():.6f}",
            )


# =============================================================================
# Tests for bmm_rcr layout - WORKING (no transpose needed)
# =============================================================================


class CuTeDSLBmmRcrTest(unittest.TestCase):
    """Tests for the BMM RCR layout - ideal for CuTeDSL MMA.

    bmm_rcr: C[B,M,N] = A[B,M,K] @ B[B,N,K]^T

    Why this works:
      - A[M,K] row-major has K contiguous ✓ (matches the MMA A operand)
      - B[N,K] col-major has K contiguous ✓ (matches the MMA B operand)
      - Both operands already match the MMA requirements, so no transpose
        is needed.

    This is equivalent to torch.bmm(A, B.transpose(-2, -1)).
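
    Concrete strides (illustrative, for the contiguous test tensors with
    B,M,N,K = 2,256,512,128 used below):
        A[2,256,128]: strides (32768, 128, 1)  -> K innermost (stride 1)
        W[2,512,128]: strides (65536, 128, 1)  -> K innermost (stride 1)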
    """

    def test_bmm_rcr(self):
        """Test bmm_rcr: C[B,M,N] = A[B,M,K] @ B[B,N,K]^T"""
        B, M, N, K = 2, 256, 512, 128
        dtype = "float16"

        # A is row-major [B, M, K], B is col-major [B, N, K]
        A = Tensor(shape=[B, M, K], dtype=dtype, name="A", is_input=True)
        W = Tensor(shape=[B, N, K], dtype=dtype, name="W", is_input=True)
        Y = ops.bmm_rcr()(A, W)
        Y._attrs["name"] = "Y"
        Y._attrs["is_output"] = True

        target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=True)
        with compile_model(Y, target, "./tmp", "test_cutedsl_bmm_rcr") as module:
            a_pt = torch.randn(B, M, K, device="cuda", dtype=torch.float16)
            w_pt = torch.randn(B, N, K, device="cuda", dtype=torch.float16)

            # PyTorch reference: A @ W^T
            y_ref = torch.bmm(a_pt, w_pt.transpose(-2, -1))

            y_ait = torch.empty(B, M, N, device="cuda", dtype=torch.float16)
            module.run_with_tensors(
                {"A": a_pt, "W": w_pt},
                {"Y": y_ait},
            )

            self.assertTrue(
                torch.allclose(y_ait, y_ref, atol=1e-1, rtol=1e-1),
                f"bmm_rcr: max diff = {(y_ait - y_ref).abs().max().item():.6f}",
            )


class CuTeDSLBmmRcrAddTest(unittest.TestCase):
    """Tests for BMM RCR with residual add - WORKING.

    bmm_rcr_add: C[B,M,N] = A[B,M,K] @ B[B,N,K]^T + D[B,M,N]
    """

    def test_bmm_rcr_add(self):
        """Test bmm_rcr_add: C[B,M,N] = A[B,M,K] @ B[B,N,K]^T + D[B,M,N]"""
        B, M, N, K = 2, 256, 512, 128
        dtype = "float16"

        A = Tensor(shape=[B, M, K], dtype=dtype, name="A", is_input=True)
        W = Tensor(shape=[B, N, K], dtype=dtype, name="W", is_input=True)
        D = Tensor(shape=[B, M, N], dtype=dtype, name="D", is_input=True)
        Y = ops.bmm_rcr_add()(A, W, D)
        Y._attrs["name"] = "Y"
        Y._attrs["is_output"] = True

        target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=True)
        with compile_model(Y, target, "./tmp", "test_cutedsl_bmm_rcr_add") as module:
            a_pt = torch.randn(B, M, K, device="cuda", dtype=torch.float16)
            w_pt = torch.randn(B, N, K, device="cuda", dtype=torch.float16)
            d_pt = torch.randn(B, M, N, device="cuda", dtype=torch.float16)

            # PyTorch reference: A @ W^T + D
            y_ref = torch.bmm(a_pt, w_pt.transpose(-2, -1)) + d_pt

            y_ait = torch.empty(B, M, N, device="cuda", dtype=torch.float16)
            module.run_with_tensors(
                {"A": a_pt, "W": w_pt, "D": d_pt},
                {"Y": y_ait},
            )

            self.assertTrue(
                torch.allclose(y_ait, y_ref, atol=1e-1, rtol=1e-1),
                f"bmm_rcr_add: max diff = {(y_ait - y_ref).abs().max().item():.6f}",
            )


if __name__ == "__main__":
    unittest.main()