Skip to content

Commit ea34233

Browse files
zoranzhao authored and meta-codesync[bot] committed
Hovering AIT backend to CuTeDSL (#1053)
Summary: Pull Request resolved: #1053 Pull Request resolved: #1051 As title Reviewed By: jijunyan Differential Revision: D94613742 fbshipit-source-id: 95978c767f494173ea92e1589274f87d174bcbca
1 parent 510b757 commit ea34233

File tree

10 files changed

+2155
-14
lines changed

10 files changed

+2155
-14
lines changed

examples/07_how_to_run_pt_model/classic_b2b_bmm_example.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,21 @@ def build_decomposed_b2b_bmm_graph(batch, seq_len, head_dim, dtype="float16"):
141141
# =============================================================================
142142

143143

144-
def run_pattern_matching_example():
144+
def run_pattern_matching_example(use_cutedsl=False):
145145
"""Test: Decomposed ops auto-fused into classic_b2b_bmm by compiler pass.
146146
147147
Builds an AIT graph from primitive ops (bmm_rcr, elementwise MUL/ADD/SIGMOID,
148148
bmm_rrr) and verifies that the fuse_b2b_bmm pass fuses them into a single
149149
classic_b2b_bmm kernel, producing results matching PyTorch.
150+
151+
Parameters
152+
----------
153+
use_cutedsl : bool
154+
If True, use CuTeDSL backend instead of CUTLASS C++ templates.
150155
"""
156+
backend_name = "CuTeDSL" if use_cutedsl else "CUTLASS C++"
151157
print("\n" + "=" * 60)
152-
print("Pattern Matching Test: decomposed ops -> classic_b2b_bmm")
158+
print(f"Pattern Matching Test: decomposed ops -> classic_b2b_bmm ({backend_name})")
153159
print("=" * 60)
154160

155161
batch, seq_len, head_dim = 4, 128, 64
@@ -166,15 +172,19 @@ def run_pattern_matching_example():
166172
y_pt = pt_model(q_pt, k_pt, v_pt, bias_pt)
167173

168174
# Build AIT graph from decomposed ops (NOT ops.classic_b2b_bmm)
169-
target = _get_target(use_fp16_acc=False)
175+
target = _get_target(use_fp16_acc=False, use_cutedsl_b2b_bmm=use_cutedsl)
170176
logging.getLogger("aitemplate").setLevel(logging.DEBUG)
171177

172178
with target:
173179
Y = build_decomposed_b2b_bmm_graph(batch, seq_len, head_dim, dtype)
174180

175181
# Compile - the fuse_b2b_bmm pass will fuse the decomposed graph
176-
print("\nCompiling... (fuse_b2b_bmm pass will pattern-match and fuse)")
177-
with compile_model(Y, target, "./tmp", "pattern_matched_b2b_bmm") as module:
182+
workdir_suffix = "cutedsl" if use_cutedsl else "cutlass"
183+
print(f"\nCompiling with {backend_name} backend...")
184+
print("(fuse_b2b_bmm pass will pattern-match and fuse)")
185+
with compile_model(
186+
Y, target, "./tmp", f"pattern_matched_b2b_bmm_{workdir_suffix}"
187+
) as module:
178188
y_ait = torch.empty_like(y_pt)
179189
module.run_with_tensors(
180190
{"Q": q_pt, "K": k_pt, "V": v_pt, "Bias": bias_pt},
@@ -189,13 +199,34 @@ def run_pattern_matching_example():
189199

190200

191201
def main():
202+
import argparse
203+
204+
parser = argparse.ArgumentParser(description="AITemplate classic_b2b_bmm example")
205+
parser.add_argument(
206+
"--use-cutedsl",
207+
action="store_true",
208+
default=False,
209+
help="Use CuTeDSL backend instead of CUTLASS C++ templates",
210+
)
211+
parser.add_argument(
212+
"--both",
213+
action="store_true",
214+
default=False,
215+
help="Run with both CUTLASS C++ and CuTeDSL backends",
216+
)
217+
args = parser.parse_args()
218+
192219
print("=" * 60)
193220
print("AITemplate classic_b2b_bmm Pattern Matching Example")
194221
print("=" * 60)
195222
print("\nDemonstrates automatic fusion of decomposed attention ops")
196223
print("into classic_b2b_bmm via the fuse_b2b_bmm compiler pass.")
197224

198-
run_pattern_matching_example()
225+
if args.both:
226+
run_pattern_matching_example(use_cutedsl=False)
227+
run_pattern_matching_example(use_cutedsl=True)
228+
else:
229+
run_pattern_matching_example(use_cutedsl=args.use_cutedsl)
199230

200231
print("\n" + "=" * 60)
201232
print("All tests passed!")
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# AITemplate classic_b2b_bmm: Graph Optimization & Code Generation Flow
2+
3+
## Overview
4+
5+
This document describes the end-to-end compilation flow when decomposed
6+
attention ops are automatically fused into a single `classic_b2b_bmm` kernel.
7+
8+
## Flow Diagram
9+
10+
```
11+
┌─────────────────────────────────────────────────────────────────┐
12+
│ User Code: build_decomposed_b2b_bmm_graph() │
13+
│ │
14+
│ Q ──► bmm_rcr(Q,K) ──► MUL(α₀) ──► ADD(bias) ──► SIGMOID │
15+
│ │ │
16+
│ MUL(α₁) │
17+
│ │ │
18+
│ bmm_rrr(score,V) ──► Y│
19+
└──────────────────────────┬──────────────────────────────────────┘
20+
21+
22+
┌──────────────────────────────────────────────────────────────────┐
23+
│ compile_model(Y, target, workdir, test_name) │
24+
│ [compiler.py] │
25+
│ │
26+
│ 1. toposort(output_tensors) │
27+
│ 2. name_graph(sorted_graph) │
28+
│ 3. optimize_graph(sorted_graph) ◄──────────────────────────┐ │
29+
│ │ │ │
30+
│ ├─ constant_folding │ │
31+
│ ├─ fuse_ops (elementwise fusions, etc.) │ │
32+
│ ├─ ★ fuse_b2b_bmm(sorted_graph) ◄───── PATTERN MATCH │ │
33+
│ │ │ │ │
34+
│ │ │ Matches chain: │ │
35+
│ │ │ bmm_rcr → MUL(const) → ADD(tensor) │ │
36+
│ │ │ → activation → [MUL(const)] → bmm_rrr │ │
37+
│ │ │ │ │
38+
│ │ │ Replaces with: │ │
39+
│ │ │ classic_b2b_bmm(Q, K, V, bias) │ │
40+
│ │ │ α₀, α₁, epilogue baked into op attrs │ │
41+
│ │ │ │ │
42+
│ │ └─ Removes 6 intermediate ops, 4+ intermediate │ │
43+
│ │ tensors │ │
44+
│ │ │ │
45+
│ ├─ memory_planning(sorted_graph) │ │
46+
│ └─ other passes... │ │
47+
│ │
48+
│ 4. codegen(sorted_graph, workdir) │
49+
│ │ │
50+
│ ├─ gen_function_src() │
51+
│ │ For each op (including classic_b2b_bmm): │
52+
│ │ ┌────────────────────────────────────────────────────┐ │
53+
│ │ │ op.gen_function() │ │
54+
│ │ │ → registry.get("cuda.classic_b2b_bmm.gen_function")│ │
55+
│ │ │ → Renders Jinja2 FUNC_TEMPLATE │ │
56+
│ │ │ → Writes <func_name>.cu │ │
57+
│ │ └────────────────────────────────────────────────────┘ │
58+
│ │ │
59+
│ ├─ ModelContainerGenerator │
60+
│ │ → func_decl(): function declarations │
61+
│ │ → func_call(): invocations in RunImpl() │
62+
│ │ → Writes model.cu, model_container.cu │
63+
│ │ │
64+
│ └─ copy_headers_and_csrc_to_workdir() │
65+
│ │
66+
│ 5. build(file_pairs, workdir, test_name) │
67+
│ │ │
68+
│ ├─ gen_makefile() │
69+
│ ├─ nvcc <func>.cu → <func>.obj │
70+
│ ├─ nvcc model.cu → model.obj │
71+
│ └─ nvcc -shared *.obj → test.so │
72+
│ │
73+
│ 6. Return Model(workdir) │
74+
└──────────────────────────┬───────────────────────────────────────┘
75+
76+
77+
┌──────────────────────────────────────────────────────────────────┐
78+
│ Runtime: module.run_with_tensors(inputs, outputs) │
79+
│ [model.py → Model class] │
80+
│ │
81+
│ 1. ctypes.CDLL loads test.so │
82+
│ 2. Sets input pointers + dynamic dims │
83+
│ 3. Calls RunImpl(stream) in C++ │
84+
│ → Invokes classic_b2b_bmm_func(output, Q, K, V, bias, │
85+
│ batch_size, num_heads, m0, k0, stream) │
86+
│ → Inside: instantiates B2bGemmBatched<...>, runs on GPU │
87+
│ 4. Returns output tensors │
88+
└──────────────────────────────────────────────────────────────────┘
89+
```
90+
91+
## Generated CUDA Code Structure
92+
93+
The backend codegen (`backend/cuda/b2b_bmm/classic_b2b_bmm.py`) produces:
94+
95+
### `<func_name>.cu` — Kernel Source
96+
```cpp
97+
#include "cutlass/cutlass.h"
98+
#include "classic_b2b_bmm/device/b2b_batched_gemm.h"
99+
100+
// Hardcoded tile sizes
101+
constexpr int ThreadblockM = 64, ThreadblockK = 32;
102+
constexpr int WarpM = 16, WarpK = 32;
103+
constexpr int N0 = <seq_len>, N1 = <head_dim>;
104+
105+
void <func_name>(void* output, void* query, void* key, void* value,
106+
void* bias, int64_t batch_size, int64_t num_heads,
107+
int64_t m0, int64_t k0, cudaStream_t stream) {
108+
// Type aliases, epilogue ops, B2bGemmBatched instantiation
109+
// Argument construction with batched/multi-head strides
110+
// Initialize and execute
111+
}
112+
```
113+
114+
### `model.cu` — Container
115+
```cpp
116+
class Model : public ModelBase<Model> {
117+
void RunImpl(StreamType stream) {
118+
// ... sets up pointers ...
119+
<func_name>(output, Q, K, V, bias, batch, heads, m0, k0, stream);
120+
}
121+
};
122+
```
123+
124+
## Key Files
125+
126+
| Component | File |
127+
|-----------|------|
128+
| Pattern matching | `compiler/transform/fuse_b2b_bmm.py` |
129+
| Op definition | `compiler/ops/b2b_bmm/classic_b2b_bmm.py` |
130+
| Base class | `compiler/ops/b2b_bmm/b2b_bmm_base.py` |
131+
| CUDA backend | `backend/cuda/b2b_bmm/classic_b2b_bmm.py` |
132+
| CUTLASS headers | `static/include/kernels/classic_b2b_bmm/` |
133+
| Compiler entry | `compiler/compiler.py` |
134+
| Code generation | `backend/codegen.py` |
135+
| Builder | `backend/builder.py` |
136+
| Runtime | `compiler/model.py` |

0 commit comments

Comments (0)