Skip to content

Commit 1999406

Browse files
committed
examples: graph: add int8 sdpa example
1 parent b17fc49 commit 1999406

2 files changed

Lines changed: 369 additions & 0 deletions

File tree

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ if(NOT ONEDNN_BUILD_GRAPH)
109109
${CMAKE_CURRENT_SOURCE_DIR}/graph/gated_mlp_int4.cpp
110110
${CMAKE_CURRENT_SOURCE_DIR}/graph/sdpa_bottom_right_causal_mask.cpp
111111
${CMAKE_CURRENT_SOURCE_DIR}/graph/gqa_training.cpp
112+
${CMAKE_CURRENT_SOURCE_DIR}/graph/int8_sdpa.cpp
112113
)
113114
endif()
114115

examples/graph/int8_sdpa.cpp

Lines changed: 368 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,368 @@
1+
/*******************************************************************************
2+
* Copyright 2025 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_graph.hpp"

#include "graph_example_utils.hpp"
30+
31+
using namespace dnnl;
32+
33+
using namespace dnnl::graph;
34+
using layout_type = logical_tensor::layout_type;
35+
using data_type = logical_tensor::data_type;
36+
using dim = logical_tensor::dim;
37+
using dims = logical_tensor::dims;
38+
39+
// Problem dimensions of one SDPA (scaled dot-product attention) test case.
// All four values shape the Q/K/V tensors as {mb, head_num, seq_len, head_size}.
struct sdpa_dims_t {
    dim mb; // minibatch size (number of sequences)
    dim seq_len; // sequence length
    dim head_num; // number of attention heads
    dim head_size; // hidden size per head
};
45+
46+
// Lower bound on the number of timed iterations used by bench_int8_sdpa().
static const int min_runs = 4;
47+
48+
// Fill `out` with pseudo-random values in [-1, 1].
//
// A fixed pool of `nrand` floats is generated once with a default-seeded
// std::mt19937 (so the data is deterministic across runs and calls) and then
// tiled over the output buffer; out[i] == pool[i % nrand]. Adapted from the
// fill_random() helper in matmul_perf.cpp.
//
// NOTE(review): the lazily-initialized function-local pool assumes
// single-threaded use, which holds for this example.
void fill_random(std::vector<float> &out) {
    static std::vector<float> random_data_f;
    constexpr size_t nrand = 1037;

    if (random_data_f.empty()) {
        std::mt19937 generator;
        std::uniform_real_distribution<float> dist_f(-1.0f, 1.0f);

        random_data_f.resize(nrand);
        for (auto &d : random_data_f)
            d = dist_f(generator);
    }

    for (size_t i = 0; i < out.size(); i += nrand) {
        size_t chunk = std::min(nrand, out.size() - i);
        // std::copy_n instead of std::memcpy: the original relied on
        // <cstring>, which this file never included.
        std::copy_n(random_data_f.begin(), chunk, out.begin() + i);
    }
}
67+
68+
// Initialize an additive attention mask: within each logical row of
// `seq_len` elements, the leading 3/4 positions are kept (0) and the
// trailing 1/4 are masked out (-inf, so they vanish under softmax).
void fill_mask(std::vector<float> &mask, size_t seq_len) {
    const size_t keep = seq_len * 3 / 4;
    const float neg_inf = -std::numeric_limits<float>::infinity();

    // Track the column inside the current row instead of computing i % seq_len.
    size_t col = 0;
    for (auto &m : mask) {
        m = (col < keep) ? 0.f : neg_inf;
        if (++col == seq_len) col = 0;
    }
}
79+
80+
// Map a logical tensor data type to a human-readable name; any type not
// listed below maps to "unknown".
const char *get_type_string(logical_tensor::data_type dt) {
    switch (dt) {
        case logical_tensor::data_type::f16: return "f16";
        case logical_tensor::data_type::f32: return "f32";
        case logical_tensor::data_type::bf16: return "bf16";
        case logical_tensor::data_type::u8: return "u8";
        case logical_tensor::data_type::s8: return "s8";
        default: return "unknown";
    }
}
94+
95+
// Print the data type and problem dimensions of a test case to stdout.
// The trailing std::flush makes the header visible even if the benchmark
// run that follows takes a long time or aborts.
void print_test_case(logical_tensor::data_type dt, const sdpa_dims_t &p) {
    std::cout << '[' << std::setw(4) << get_type_string(dt)
              << " mb = " << p.mb << ", seq_len = " << p.seq_len
              << ", head_num = " << p.head_num
              << ", head_size = " << p.head_size << "] " << std::flush;
}
102+
103+
// Build, compile, and execute an int8 (u8) scaled-dot-product-attention
// graph, then report timing.
//
// The graph computes, in f32 between quantization boundaries:
//     output = softmax(Q x K^T / scale + mask) x V
// with Dequantize ops in front of Q/K/V, a Quantize/Dequantize pair after
// the softmax, and a final Quantize producing the u8 output. All
// quantization here is per-tensor with scale 0.25 and zero point 128.
//
// @param ekind      engine kind to run on (CPU or GPU).
// @param p          problem dimensions.
// @param time_limit time budget in milliseconds for the timing loop; the
//                   default 0 selects a quick single-run test with no
//                   timing loop.
void bench_int8_sdpa(
        engine::kind ekind, const sdpa_dims_t &p, double time_limit = 0.) {
    const bool quick_test = (time_limit == 0.);
    print_test_case(data_type::u8, p);

    allocator alloc = create_allocator(ekind);

    // Create execution dnnl::engine.
    dnnl::engine eng = make_engine_with_allocator(ekind, 0, alloc);
    // Create dnnl::stream.
    dnnl::stream strm(eng);

    // Prepare input and output shapes to construct the sdpa graph.
    const dims qkv_sz = {p.mb, p.head_num, p.seq_len, p.head_size};
    const dims score_sz = {p.mb, p.head_num, p.seq_len, p.seq_len};
    const dims scale_sz = {1};
    // Mask broadcasts over heads and query rows: one row per batch entry.
    const dims mask_sz = {p.mb, 1, 1, p.seq_len};

    // Incremental IDs used to create logical tensors and operations.
    // NOTE: IDs must be unique within the graph; the creation order below
    // is what assigns them, so do not reorder these definitions.
    size_t id = 0;

    // insert the dequant for u8 query to f32 query
    auto q_u8
            = logical_tensor(id++, data_type::u8, qkv_sz, layout_type::strided);
    auto q_f32 = logical_tensor(
            id++, data_type::f32, qkv_sz, layout_type::strided);
    auto q_deq = op(id++, op::kind::Dequantize, "q_deq");
    q_deq.set_attr<std::string>(op::attr::qtype, "per_tensor");
    q_deq.set_attr<float>(op::attr::scales, 0.25f);
    q_deq.set_attr<int64_t>(op::attr::zps, 128);
    q_deq.add_input(q_u8);
    q_deq.add_output(q_f32);

    // insert the dequant for u8 key to f32 key
    auto k_u8
            = logical_tensor(id++, data_type::u8, qkv_sz, layout_type::strided);
    auto k_f32 = logical_tensor(
            id++, data_type::f32, qkv_sz, layout_type::strided);
    auto k_deq = op(id++, op::kind::Dequantize, "k_deq");
    k_deq.set_attr<std::string>(op::attr::qtype, "per_tensor");
    k_deq.set_attr<float>(op::attr::scales, 0.25f);
    k_deq.set_attr<int64_t>(op::attr::zps, 128);
    k_deq.add_input(k_u8);
    k_deq.add_output(k_f32);

    // score = query x key.T.
    auto score = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto bmm1 = op(id++, op::kind::MatMul, "bmm1");
    // Transpose K inside the matmul rather than materializing K^T.
    bmm1.set_attr<bool>(op::attr::transpose_b, true);
    bmm1.add_inputs({q_f32, k_f32});
    bmm1.add_output(score);

    // scaled_score = score / scale
    auto scale = logical_tensor(
            id++, data_type::f32, scale_sz, layout_type::strided);
    auto scaled_score = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto scale_div = op(id++, op::kind::Divide, "scale_div");
    scale_div.add_inputs({score, scale});
    scale_div.add_outputs({scaled_score});

    // masked_score = scaled_score + mask
    auto mask = logical_tensor(
            id++, data_type::f32, mask_sz, layout_type::strided);
    auto masked_score = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto mask_add = op(id++, op::kind::Add, "mask_add");
    mask_add.add_inputs({scaled_score, mask});
    mask_add.add_outputs({masked_score});

    // attention_probs = softmax(masked_score)
    auto probs = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto softmax = op(id++, op::kind::SoftMax, "softmax");
    // Softmax over the last axis (the key/sequence dimension).
    softmax.set_attr<int64_t>(op::attr::axis, -1);
    softmax.add_inputs({masked_score});
    softmax.add_outputs({probs});

    // quantize the probs from f32 to u8
    auto probs_u8 = logical_tensor(
            id++, data_type::u8, score_sz, layout_type::strided);
    auto p_quant = op(id++, op::kind::Quantize, "p_quant");
    p_quant.set_attr<std::string>(op::attr::qtype, "per_tensor");
    p_quant.set_attr<float>(op::attr::scales, 0.25f);
    p_quant.set_attr<int64_t>(op::attr::zps, 128);
    p_quant.add_input(probs);
    p_quant.add_output(probs_u8);

    // dequant the probs from u8 to f32
    auto probs_f32 = logical_tensor(
            id++, data_type::f32, score_sz, layout_type::strided);
    auto p_deq = op(id++, op::kind::Dequantize, "p_deq");
    p_deq.set_attr<std::string>(op::attr::qtype, "per_tensor");
    p_deq.set_attr<float>(op::attr::scales, 0.25f);
    p_deq.set_attr<int64_t>(op::attr::zps, 128);
    p_deq.add_input(probs_u8);
    p_deq.add_output(probs_f32);

    // dequant the value from u8 to f32
    auto v_u8
            = logical_tensor(id++, data_type::u8, qkv_sz, layout_type::strided);
    auto v_f32 = logical_tensor(
            id++, data_type::f32, qkv_sz, layout_type::strided);
    auto v_deq = op(id++, op::kind::Dequantize, "v_deq");
    v_deq.set_attr<std::string>(op::attr::qtype, "per_tensor");
    v_deq.set_attr<float>(op::attr::scales, 0.25f);
    v_deq.set_attr<int64_t>(op::attr::zps, 128);
    v_deq.add_input(v_u8);
    v_deq.add_output(v_f32);

    // attention_output = attention_probs x value.
    auto output = logical_tensor(
            id++, data_type::f32, qkv_sz, layout_type::strided);
    auto bmm2 = op(id++, op::kind::MatMul, "bmm2");
    bmm2.add_inputs({probs_f32, v_f32});
    bmm2.add_outputs({output});

    // quantize the output from f32 to u8
    auto output_u8
            = logical_tensor(id++, data_type::u8, qkv_sz, layout_type::strided);
    auto o_quant = op(id++, op::kind::Quantize, "o_quant");
    o_quant.set_attr<std::string>(op::attr::qtype, "per_tensor");
    o_quant.set_attr<float>(op::attr::scales, 0.25f);
    o_quant.set_attr<int64_t>(op::attr::zps, 128);
    o_quant.add_input(output);
    o_quant.add_output(output_u8);

    // Construct a sdpa graph with engine kind and operations.
    dnnl::graph::graph sdpa(ekind);
    sdpa.add_op(q_deq);
    sdpa.add_op(k_deq);
    sdpa.add_op(bmm1);
    sdpa.add_op(scale_div);
    sdpa.add_op(mask_add);
    sdpa.add_op(softmax);
    sdpa.add_op(p_quant);
    sdpa.add_op(p_deq);
    sdpa.add_op(v_deq);
    sdpa.add_op(bmm2);
    sdpa.add_op(o_quant);
    sdpa.finalize();

    // Get partitions from the sdpa graph.
    std::vector<partition> partitions = sdpa.get_partitions();
    // This is just for oneDNN testing purpose.
    // Exactly one partition means the whole pattern was fused; anything
    // else is treated as unsupported and the benchmark is skipped.
    if (partitions.size() != 1) {
        std::cout << "unsupported sdpa" << std::endl;
        return;
    }

    // Compile the partition with inputs, outputs, and an engine.
    compiled_partition cp = partitions[0].compile(
            {q_u8, k_u8, scale, mask, v_u8}, {output_u8}, eng);

    // Create tensor objects
    auto ts_query = tensor(q_u8, eng);
    auto ts_key = tensor(k_u8, eng);
    auto ts_scale = tensor(scale, eng);
    auto ts_mask = tensor(mask, eng);
    auto ts_value = tensor(v_u8, eng);
    auto ts_output = tensor(output_u8, eng);

    // Allocate user data.
    std::vector<float> query_data(product(qkv_sz));
    std::vector<float> key_data(product(qkv_sz));
    // Scale divisor is sqrt(head_size), the conventional SDPA scaling.
    std::vector<float> scale_data(
            product(scale_sz), (float)std::sqrt(p.head_size));
    std::vector<float> mask_data(product(mask_sz));
    std::vector<float> value_data(product(qkv_sz));
    // NOTE(review): output_data is allocated but the result is never read
    // back in this benchmark.
    std::vector<float> output_data(product(qkv_sz));

    fill_random(query_data);
    fill_random(key_data);
    fill_random(value_data);
    fill_mask(mask_data, static_cast<size_t>(p.seq_len));

    // Write data to tensor object's handle.
    write_to_dnnl_tensor(query_data.data(), ts_query);
    write_to_dnnl_tensor(key_data.data(), ts_key);
    write_to_dnnl_tensor(scale_data.data(), ts_scale);
    write_to_dnnl_tensor(mask_data.data(), ts_mask);
    write_to_dnnl_tensor(value_data.data(), ts_value);

    // Warmup run.
    // Execute the compiled partition of sdpa.
    cp.execute(
            strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value}, {ts_output});

    // Wait for the computation to finish.
    strm.wait();

    // First run.
    auto start_first = std::chrono::steady_clock::now();
    cp.execute(
            strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value}, {ts_output});
    strm.wait();
    auto end_first = std::chrono::steady_clock::now();
    std::chrono::duration<double, std::milli> dur_first
            = end_first - start_first;

    if (quick_test) return;

    // Timing runs.
    // Size the loop so that roughly `time_limit` ms of work is done, but
    // never fewer than min_runs iterations.
    const int runs = std::max(min_runs, int(time_limit / dur_first.count()));
    auto start = std::chrono::steady_clock::now();
    for (int i = 0; i <= runs; i++)
        cp.execute(strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value},
                {ts_output});
    strm.wait();
    auto end = std::chrono::steady_clock::now();
    std::chrono::duration<double, std::milli> duration = end - start;

    // Display the results.
    // The loop above executes runs + 1 times; one first-run duration is
    // subtracted as an estimate of the extra iteration before averaging.
    double avg_time = (duration.count() - dur_first.count()) / runs;
    std::cout << "graph runs: " << runs + 1 << "; ";
    std::cout << "avg_time: " << avg_time << " ms" << std::endl;
}
321+
322+
// Report usage on stderr and abort argument handling by throwing
// std::invalid_argument (caught by the example error handler in main).
void bad_args() {
    const char *usage
            = "Usage: graph-int8-sdpa-cpp [cpu|gpu]\n"
              " graph-int8-sdpa-cpp [cpu|gpu] <mb> <seq_len> "
              "<head_num> <head_size>\n\n";
    std::cerr << usage;
    throw std::invalid_argument("Incorrect input arguments.");
}
328+
329+
// Run the int8 SDPA benchmark for one problem size, releasing the example
// memory pool afterwards. Unimplemented configurations are reported and
// swallowed; any other dnnl::error propagates to the caller.
//
// @param ekind      engine kind to run on (CPU or GPU).
// @param p          problem dimensions.
// @param time_limit timing budget in ms forwarded to bench_int8_sdpa();
//                   0 (default) requests a quick single-run test.
void bench(engine::kind ekind, const sdpa_dims_t &p, double time_limit = 0.) {
    try {
        bench_int8_sdpa(ekind, p, time_limit);
        get_mem_pool().clear();
    } catch (dnnl::error &e) {
        // Catch and report unimplemented cases.
        if (e.status == dnnl_unimplemented) {
            // Fix: the message previously ended with a dangling ": " and
            // printed nothing after it; include the library's description.
            std::cout << "unsupported sdpa: " << e.what() << std::endl;
        } else
            throw;
    }
}
341+
342+
// Parse optional command-line dimension overrides and run the benchmark.
//
// Accepted forms (argv[1] is the engine kind, consumed by the caller):
//   <binary> [cpu|gpu]                                     -- default sizes
//   <binary> [cpu|gpu] <mb> <seq_len> <head_num> <head_size>
// Any other argument count, or a non-positive dimension, aborts via
// bad_args().
void sdpa_perf(engine::kind ekind, int argc, char **argv) {
    // default testing parameters
    sdpa_dims_t params = {32, 384, 16, 64};

    if (argc > 2) {
        // Guard clause: bad_args() throws, so parsing below only runs when
        // exactly four dimension arguments were supplied.
        if (argc != 6) bad_args();

        params.mb = std::atoi(argv[2]);
        params.seq_len = std::atoi(argv[3]);
        params.head_num = std::atoi(argv[4]);
        params.head_size = std::atoi(argv[5]);

        // std::atoi yields 0 for non-numeric input, which this check rejects.
        const bool valid = params.mb > 0 && params.seq_len > 0
                && params.head_num > 0 && params.head_size > 0;
        if (!valid) bad_args();
    }

    bench(ekind, params, 2000.0 /*ms*/);
}
364+
365+
int main(int argc, char **argv) {
366+
return handle_example_errors(
367+
sdpa_perf, parse_engine_kind(argc, argv, 4), argc, argv);
368+
}

0 commit comments

Comments
 (0)