Skip to content

Commit e0dbf98

Browse files
zoranzhao authored and facebook-github-bot committed
CuTeDSL bmm _xxx and bmm_xxx_add backend on ampere/hopper (#1056)
Summary: Pull Request resolved: #1056 Reviewed By: jijunyan Differential Revision: D95449060
1 parent d0b1173 commit e0dbf98

File tree

6 files changed

+2811
-0
lines changed

6 files changed

+2811
-0
lines changed

examples/test_cutedsl_bmm.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Tests for CuTeDSL backends of BMM operations.
16+
17+
ALL TESTS PASSING ON H100 (SM90)! ✓
18+
19+
This test file covers multiple BMM layout combinations with CuTeDSL kernels:
20+
- bmm_rcr: C[B,M,N] = A[B,M,K] @ B[B,N,K]^T (IDEAL - no transpose needed) ✓
21+
- bmm_rcr_add: C[B,M,N] = A[B,M,K] @ B[B,N,K]^T + D[B,M,N] ✓
22+
- bmm_rrr: C[B,M,N] = A[B,M,K] @ B[B,K,N] ✓
23+
- bmm_ccr: C[B,M,N] = A[B,K,M]^T @ B[B,N,K]^T ✓
24+
- bmm_rrr_add: C[B,M,N] = A[B,M,K] @ B[B,K,N] + D[B,M,N] ✓
25+
- bmm_ccr_add: C[B,M,N] = A[B,K,M]^T @ B[B,N,K]^T + D[B,M,N] ✓
26+
27+
Key fix: The tensor descriptor format was corrected to:
28+
- dynamic_shapes[0] = B_dim (batch size)
29+
- dynamic_shapes[1] = first inner dimension extent (M, K, or N)
30+
- dynamic_strides[0] = batch stride
31+
32+
MMA Layout Requirements:
33+
- A operand: [M, K] with K contiguous
34+
- B operand: [N, K] with K contiguous
35+
36+
Layout analysis:
37+
- RCR: A[M,K] K-contiguous ✓, B[N,K] K-contiguous ✓ (IDEAL, no transpose)
38+
- RRR: A[M,K] K-contiguous ✓, B[K,N] needs logical transpose
39+
- CCR: A[K,M] needs logical transpose, B[N,K] K-contiguous ✓
40+
41+
Run with:
42+
buck run fbcode//aitemplate/AITemplate/examples:test_cutedsl_bmm
43+
"""
44+
45+
import unittest
46+
47+
import torch
48+
from aitemplate.compiler import compile_model, ops
49+
from aitemplate.frontend import Tensor
50+
from aitemplate.testing.detect_target import FBCUDA
51+
52+
53+
def _get_target(**kwargs):
    """Build an FBCUDA backend target for the local GPU.

    Requires a device with compute capability SM80 or newer (Ampere+);
    raises RuntimeError otherwise. Extra keyword arguments are forwarded
    to the FBCUDA constructor.
    """
    major, minor = torch.cuda.get_device_capability(0)
    cc = major * 10 + minor
    if cc < 80:
        raise RuntimeError(f"SM80+ required, got SM{cc}")
    return FBCUDA(arch=str(cc), **kwargs)
59+
60+
61+
# All tests passing on H100 (SM90)!
62+
class CuTeDSLBmmTest(unittest.TestCase):
    """Tests for basic BMM operations (bmm_ccr, bmm_rrr).

    These tests use layouts that require logical transpose handling.
    Fixed by correcting the tensor descriptor format in the wrapper.
    """

    def _assert_close(self, actual, expected, tag):
        # Shared tolerance check; loose bounds account for fp16 accumulation.
        max_diff = (actual - expected).abs().max().item()
        self.assertTrue(
            torch.allclose(actual, expected, atol=1e-1, rtol=1e-1),
            f"{tag}: max diff = {max_diff:.6f}",
        )

    def test_bmm_rrr(self):
        """Test bmm_rrr: C[B,M,N] = A[B,M,K] @ B[B,K,N]"""
        batch, m, n, k = 2, 256, 512, 128
        dtype = "float16"

        lhs = Tensor(shape=[batch, m, k], dtype=dtype, name="A", is_input=True)
        rhs = Tensor(shape=[batch, k, n], dtype=dtype, name="W", is_input=True)
        out = ops.bmm_rrr()(lhs, rhs)
        out._attrs["name"] = "Y"
        out._attrs["is_output"] = True

        target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=True)
        with compile_model(out, target, "./tmp", "test_cutedsl_bmm_rrr") as module:
            a = torch.randn(batch, m, k, device="cuda", dtype=torch.float16)
            w = torch.randn(batch, k, n, device="cuda", dtype=torch.float16)

            # PyTorch reference: standard batched matmul
            expected = torch.bmm(a, w)

            actual = torch.empty(batch, m, n, device="cuda", dtype=torch.float16)
            module.run_with_tensors(
                {"A": a, "W": w},
                {"Y": actual},
            )

            self._assert_close(actual, expected, "bmm_rrr")

    def test_bmm_ccr(self):
        """Test bmm_ccr: C[B,M,N] = A[B,K,M]^T @ B[B,N,K]^T"""
        batch, m, n, k = 2, 256, 512, 128
        dtype = "float16"

        # A is col-major [B, K, M], B is col-major [B, N, K]
        lhs = Tensor(shape=[batch, k, m], dtype=dtype, name="A", is_input=True)
        rhs = Tensor(shape=[batch, n, k], dtype=dtype, name="W", is_input=True)
        out = ops.bmm_ccr()(lhs, rhs)
        out._attrs["name"] = "Y"
        out._attrs["is_output"] = True

        target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=True)
        with compile_model(out, target, "./tmp", "test_cutedsl_bmm_ccr") as module:
            a = torch.randn(batch, k, m, device="cuda", dtype=torch.float16)
            w = torch.randn(batch, n, k, device="cuda", dtype=torch.float16)

            # PyTorch reference:
            # bmm_ccr: A^T @ W^T = transpose(A, -2, -1) @ transpose(W, -2, -1)
            expected = torch.bmm(
                a.transpose(-2, -1),  # [B, M, K]
                w.transpose(-2, -1),  # [B, K, N]
            )

            actual = torch.empty(batch, m, n, device="cuda", dtype=torch.float16)
            module.run_with_tensors(
                {"A": a, "W": w},
                {"Y": actual},
            )

            self._assert_close(actual, expected, "bmm_ccr")
132+
133+
134+
# All tests passing on H100 (SM90)!
135+
class CuTeDSLBmmAddTest(unittest.TestCase):
    """Tests for BMM with residual add (bmm_ccr_add, bmm_rrr_add).

    Fixed by correcting the tensor descriptor format in the wrapper.
    """

    def _assert_close(self, actual, expected, tag):
        # Shared tolerance check; loose bounds account for fp16 accumulation.
        max_diff = (actual - expected).abs().max().item()
        self.assertTrue(
            torch.allclose(actual, expected, atol=1e-1, rtol=1e-1),
            f"{tag}: max diff = {max_diff:.6f}",
        )

    def test_bmm_rrr_add(self):
        """Test bmm_rrr_add: C[B,M,N] = A[B,M,K] @ B[B,K,N] + D[B,M,N]"""
        batch, m, n, k = 2, 256, 512, 128
        dtype = "float16"

        lhs = Tensor(shape=[batch, m, k], dtype=dtype, name="A", is_input=True)
        rhs = Tensor(shape=[batch, k, n], dtype=dtype, name="W", is_input=True)
        residual = Tensor(shape=[batch, m, n], dtype=dtype, name="D", is_input=True)
        out = ops.bmm_rrr_add()(lhs, rhs, residual)
        out._attrs["name"] = "Y"
        out._attrs["is_output"] = True

        target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=True)
        with compile_model(out, target, "./tmp", "test_cutedsl_bmm_rrr_add") as module:
            a = torch.randn(batch, m, k, device="cuda", dtype=torch.float16)
            w = torch.randn(batch, k, n, device="cuda", dtype=torch.float16)
            d = torch.randn(batch, m, n, device="cuda", dtype=torch.float16)

            # PyTorch reference: bmm + add
            expected = torch.bmm(a, w) + d

            actual = torch.empty(batch, m, n, device="cuda", dtype=torch.float16)
            module.run_with_tensors(
                {"A": a, "W": w, "D": d},
                {"Y": actual},
            )

            self._assert_close(actual, expected, "bmm_rrr_add")

    def test_bmm_ccr_add(self):
        """Test bmm_ccr_add: C[B,M,N] = A[B,K,M]^T @ B[B,N,K]^T + D[B,M,N]"""
        batch, m, n, k = 2, 256, 512, 128
        dtype = "float16"

        # A is col-major [B, K, M], B is col-major [B, N, K]
        lhs = Tensor(shape=[batch, k, m], dtype=dtype, name="A", is_input=True)
        rhs = Tensor(shape=[batch, n, k], dtype=dtype, name="W", is_input=True)
        residual = Tensor(shape=[batch, m, n], dtype=dtype, name="D", is_input=True)
        out = ops.bmm_ccr_add()(lhs, rhs, residual)
        out._attrs["name"] = "Y"
        out._attrs["is_output"] = True

        target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=True)
        with compile_model(out, target, "./tmp", "test_cutedsl_bmm_ccr_add") as module:
            a = torch.randn(batch, k, m, device="cuda", dtype=torch.float16)
            w = torch.randn(batch, n, k, device="cuda", dtype=torch.float16)
            d = torch.randn(batch, m, n, device="cuda", dtype=torch.float16)

            # PyTorch reference:
            # bmm_ccr_add: A^T @ W^T + D
            expected = (
                torch.bmm(
                    a.transpose(-2, -1),  # [B, M, K]
                    w.transpose(-2, -1),  # [B, K, N]
                )
                + d
            )

            actual = torch.empty(batch, m, n, device="cuda", dtype=torch.float16)
            module.run_with_tensors(
                {"A": a, "W": w, "D": d},
                {"Y": actual},
            )

            self._assert_close(actual, expected, "bmm_ccr_add")
208+
209+
210+
# =============================================================================
211+
# Tests for bmm_rcr layout - WORKING (no transpose needed)
212+
# =============================================================================
213+
214+
215+
class CuTeDSLBmmRcrTest(unittest.TestCase):
    """Tests for BMM RCR layout - IDEAL for CuTeDSL MMA.

    bmm_rcr: C[B,M,N] = A[B,M,K] @ B[B,N,K]^T

    Why this works:
    - A[M,K] row-major has K contiguous (matches MMA A operand)
    - B[N,K] col-major has K contiguous (matches MMA B operand)
    - Both operands match MMA requirements - NO TRANSPOSE NEEDED!

    This is equivalent to torch.bmm(A, B.transpose(-2, -1)).
    """

    def test_bmm_rcr(self):
        """Test bmm_rcr: C[B,M,N] = A[B,M,K] @ B[B,N,K]^T"""
        batch, m, n, k = 2, 256, 512, 128
        dtype = "float16"

        # A is row-major [B, M, K], B is col-major [B, N, K]
        lhs = Tensor(shape=[batch, m, k], dtype=dtype, name="A", is_input=True)
        rhs = Tensor(shape=[batch, n, k], dtype=dtype, name="W", is_input=True)
        out = ops.bmm_rcr()(lhs, rhs)
        out._attrs["name"] = "Y"
        out._attrs["is_output"] = True

        target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=True)
        with compile_model(out, target, "./tmp", "test_cutedsl_bmm_rcr") as module:
            a = torch.randn(batch, m, k, device="cuda", dtype=torch.float16)
            w = torch.randn(batch, n, k, device="cuda", dtype=torch.float16)

            # PyTorch reference: A @ W^T
            expected = torch.bmm(a, w.transpose(-2, -1))

            actual = torch.empty(batch, m, n, device="cuda", dtype=torch.float16)
            module.run_with_tensors(
                {"A": a, "W": w},
                {"Y": actual},
            )

            # Loose tolerances account for fp16 accumulation differences.
            max_diff = (actual - expected).abs().max().item()
            self.assertTrue(
                torch.allclose(actual, expected, atol=1e-1, rtol=1e-1),
                f"bmm_rcr: max diff = {max_diff:.6f}",
            )
258+
259+
260+
class CuTeDSLBmmRcrAddTest(unittest.TestCase):
    """Tests for BMM RCR with residual add - WORKING.

    bmm_rcr_add: C[B,M,N] = A[B,M,K] @ B[B,N,K]^T + D[B,M,N]
    """

    def test_bmm_rcr_add(self):
        """Test bmm_rcr_add: C[B,M,N] = A[B,M,K] @ B[B,N,K]^T + D[B,M,N]"""
        batch, m, n, k = 2, 256, 512, 128
        dtype = "float16"

        lhs = Tensor(shape=[batch, m, k], dtype=dtype, name="A", is_input=True)
        rhs = Tensor(shape=[batch, n, k], dtype=dtype, name="W", is_input=True)
        residual = Tensor(shape=[batch, m, n], dtype=dtype, name="D", is_input=True)
        out = ops.bmm_rcr_add()(lhs, rhs, residual)
        out._attrs["name"] = "Y"
        out._attrs["is_output"] = True

        target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=True)
        with compile_model(out, target, "./tmp", "test_cutedsl_bmm_rcr_add") as module:
            a = torch.randn(batch, m, k, device="cuda", dtype=torch.float16)
            w = torch.randn(batch, n, k, device="cuda", dtype=torch.float16)
            d = torch.randn(batch, m, n, device="cuda", dtype=torch.float16)

            # PyTorch reference: A @ W^T + D
            expected = torch.bmm(a, w.transpose(-2, -1)) + d

            actual = torch.empty(batch, m, n, device="cuda", dtype=torch.float16)
            module.run_with_tensors(
                {"A": a, "W": w, "D": d},
                {"Y": actual},
            )

            # Loose tolerances account for fp16 accumulation differences.
            max_diff = (actual - expected).abs().max().item()
            self.assertTrue(
                torch.allclose(actual, expected, atol=1e-1, rtol=1e-1),
                f"bmm_rcr_add: max diff = {max_diff:.6f}",
            )
297+
298+
299+
if __name__ == "__main__":
    # Run the full BMM test suite when this file is executed directly.
    unittest.main()

python/aitemplate/backend/cuda/gemm_universal/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
bmm_rrr_permute,
1919
bmm_xxx,
2020
bmm_xxx_add,
21+
cutedsl_bmm,
2122
gemm_rcr_bias,
2223
gemm_rcr_bias_activation_cutedsl,
2324
gemm_rcr_bias_cutedsl,

0 commit comments

Comments
 (0)