Skip to content

Commit c7face3

Browse files
committed
xe: sdpa, ggemm, gmlp: avoid copies
1 parent 5d6fe4d commit c7face3

11 files changed

Lines changed: 18 additions & 17 deletions

File tree

src/gpu/intel/gated_mlp/micro_horz.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ status_t micro_horz_t::pd_t::init_microkernels(
197197
&& (problem.Tb != gemmstone::Type::invalid),
198198
status::unimplemented, "Incompatible A/B types in uGEMM.");
199199

200-
auto problem_wgu = problem;
200+
auto problem_wgu = std::move(problem);
201201
problem_wgu.A.layout = gemmstone::MatrixLayout::T;
202202
problem_wgu.B.layout = gemmstone::MatrixLayout::Pr;
203203
problem_wgu.C.layout = gemmstone::MatrixLayout::T;

src/gpu/intel/gated_mlp/ref.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ struct ref_t : public primitive_t {
161161

162162
status_t execute(const exec_ctx_t &ctx) const override {
163163
auto prep_weights_and_run
164-
= [&](exec_args_t &args, int idx,
164+
= [&](exec_args_t &&args, int idx,
165165
const std::shared_ptr<impl::primitive_t> &prim) {
166166
args[DNNL_ARG_WEIGHTS] = ctx.args().at(idx);
167167
if (!pd()->attr()->scales_.has_default_values(idx))
@@ -201,7 +201,8 @@ struct ref_t : public primitive_t {
201201
exec_args_t args;
202202
args[DNNL_ARG_SRC] = ctx.args().at(DNNL_ARG_SRC);
203203
args[DNNL_ARG_DST] = memory_arg_t {inter_src_mem.get(), false};
204-
CHECK(prep_weights_and_run(args, DNNL_ARG_WEIGHTS_UP, gemm_up_));
204+
CHECK(prep_weights_and_run(
205+
std::move(args), DNNL_ARG_WEIGHTS_UP, gemm_up_));
205206
} while (false);
206207
do {
207208
exec_args_t args;
@@ -211,14 +212,14 @@ struct ref_t : public primitive_t {
211212
args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1]
212213
= memory_arg_t {inter_src_mem.get(), true};
213214
CHECK(prep_weights_and_run(
214-
args, DNNL_ARG_WEIGHTS_GATE, gemm_gate_));
215+
std::move(args), DNNL_ARG_WEIGHTS_GATE, gemm_gate_));
215216
} while (false);
216217
do {
217218
exec_args_t args;
218219
args[DNNL_ARG_SRC] = memory_arg_t {inter_wei_mem.get(), true};
219220
args[DNNL_ARG_DST] = ctx.args().at(DNNL_ARG_DST);
220221
CHECK(prep_weights_and_run(
221-
args, DNNL_ARG_WEIGHTS_DOWN, gemm_down_));
222+
std::move(args), DNNL_ARG_WEIGHTS_DOWN, gemm_down_));
222223
} while (false);
223224

224225
return status::success;

src/gpu/intel/gemm/jit/gen_kernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ status_t gen_desc_t::finalize(const char *tags) {
166166
problem_.beta = stringToScalar(val);
167167

168168
ovr_strategy = ss.str().substr(ss.tellg()); // remaining string
169-
parseStrategy(ovr_strategy.c_str(), hw_, problem_, strategy_);
169+
parseStrategy(ovr_strategy, hw_, problem_, strategy_);
170170

171171
// TODO: override derived values in aux_params_ in a way that's
172172
// consistent with the kernel evaluator (typically requires extra

src/gpu/intel/gemm/jit/generator/microkernel/shim.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ std::string generateShim(const Package &package, HostLanguage language,
231231
auto nargs = int(pargs.size());
232232

233233
/* Match up protocol settings with microkernel settings */
234-
auto psettings = package.protocol.settings();
234+
const auto &psettings = package.protocol.settings();
235235
auto settings = matchProtocol(psettings, package.settings);
236236

237237
/* Collect actual argument types */

src/gpu/intel/gemm/jit/generator/microkernel_selector.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ GEMMOptions GEMMOptions::transpose() const {
147147
return ret;
148148
}
149149

150-
std::string strategyToString(HW hw, GEMMProblem problem, GEMMStrategy strategy) {
150+
std::string strategyToString(HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) {
151151
std::stringstream ss;
152152
ss << problem.toString() << " "
153153
<< std::to_string(strategy.unroll[LoopM])

src/gpu/intel/gemm/jit/generator/strategy_parser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ static void getTiling(std::stringstream &s, MatrixAddressingStrategy &astrategy)
170170
}
171171
}
172172

173-
void parseStrategy(const char *str, HW hw, const GEMMProblem &problem, GEMMStrategy &strategy)
173+
void parseStrategy(const std::string &str, HW hw, const GEMMProblem &problem, GEMMStrategy &strategy)
174174
{
175175
std::stringstream s(str);
176176
s.imbue(std::locale::classic());

src/gpu/intel/gemm/jit/include/gemmstone/strategy_parser.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
GEMMSTONE_NAMESPACE_START
2727

28-
void parseStrategy(const char *str, ngen::HW hw, const GEMMProblem &problem, GEMMStrategy &strategy);
28+
void parseStrategy(const std::string &str, ngen::HW hw, const GEMMProblem &problem, GEMMStrategy &strategy);
2929

3030
void adjustStrategy(ngen::HW hw, const GEMMProblem &problem, GEMMStrategy &strategy, const char *tags = nullptr);
3131

src/gpu/intel/matmul/grouped_micro_gemm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ status_t grouped_micro_gemm_t::pd_t::init_microkernels(impl::engine_t *engine) {
190190
Scalar alpha(a), beta(b);
191191
std::string strategyString;
192192
std::getline(ss >> std::ws, strategyString);
193-
parseStrategy(strategyString.c_str(), hw, problem, strat);
193+
parseStrategy(strategyString, hw, problem, strat);
194194
adjustStrategy(hw, problem, strat);
195195
}
196196
strategyGRFs_ = strat.GRFs;

src/gpu/intel/sdpa/micro.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) {
625625
ukernel_params.sizes_kq = {heuristic_sizes};
626626

627627
/* Set up GEMMProblem structure for second GEMM: V * S */
628-
auto problem_vs = std::move(problem);
628+
auto problem_vs = problem;
629629
problem_vs.Tc = problem_vs.Ts
630630
= (vs_acc_dt() == data_type::f16) ? Type::f16 : Type::f32;
631631

@@ -723,7 +723,7 @@ status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) {
723723
ukernel_params.opts_qdSt = {opts_qdSt};
724724

725725
// dS * K
726-
auto problem_ktq = problem;
726+
auto problem_ktq = std::move(problem);
727727
problem_ktq.Ta_ext
728728
= convert_dnnl_to_kernel_type(desc()->key_md()->data_type);
729729

@@ -1180,7 +1180,7 @@ status_t micro_fwd_params_t::get_kernel_ctx(
11801180
ss >> strat.unroll[1];
11811181
std::string strategyString;
11821182
std::getline(ss >> std::ws, strategyString);
1183-
parseStrategy(strategyString.c_str(), hw, problem_kq, strat);
1183+
parseStrategy(strategyString, hw, problem_kq, strat);
11841184
adjustStrategy(hw, problem_kq, strat);
11851185
}
11861186
};
@@ -1208,7 +1208,7 @@ status_t micro_fwd_params_t::get_kernel_ctx(
12081208
ss >> strat.unroll[1];
12091209
std::string strategyString;
12101210
std::getline(ss >> std::ws, strategyString);
1211-
parseStrategy(strategyString.c_str(), hw, problem_vs, strat);
1211+
parseStrategy(strategyString, hw, problem_vs, strat);
12121212
adjustStrategy(hw, problem_vs, strat);
12131213
}
12141214
};

tests/gtests/internals/test_sdpa.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1677,7 +1677,7 @@ std::chrono::nanoseconds prim_sdpa_quant_bwd(const sdpa_dims_t &p,
16771677
strm.wait();
16781678
if (dropout_mask_fwd_out != nullptr && p.dropout.enabled()
16791679
&& p.dropout.has_output_mask()) {
1680-
*dropout_mask_fwd_out = softmax_dropout_mask;
1680+
*dropout_mask_fwd_out = std::move(softmax_dropout_mask);
16811681
}
16821682

16831683
#if DEBUG_PRINT_MEM

0 commit comments

Comments
 (0)