Skip to content

Commit 5b52dfb

Browse files
committed
cpu: rv64: remove wei pack inside primitive
1 parent bd2aaa0 commit 5b52dfb

2 files changed

Lines changed: 16 additions & 30 deletions

File tree

src/cpu/rv64/rvv_winograd_convolution.cpp

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -537,11 +537,7 @@ struct jit_wino_output_transform_t : public jit_generator_t {
537537

538538
} // namespace
539539

540-
status_t rvv_wino_resource_t::configure(
541-
size_t weight_buf_size, const rvv_winograd_conf_t &conf) {
542-
weight_buf_.reset(new float[weight_buf_size]);
543-
if (!weight_buf_) return status::out_of_memory;
544-
540+
status_t rvv_wino_resource_t::configure(const rvv_winograd_conf_t &conf) {
545541
auto *input = new jit_wino_input_transform_t(conf);
546542
input->create_kernel();
547543
input_xform_.reset(input);
@@ -629,10 +625,10 @@ status_t rvv_winograd_init_conf(rvv_winograd_conf_t &conf,
629625
conf.wspec.V_buffer_size = conf.wspec.n_gemms * conf.wspec.input_ld_batch;
630626
conf.wspec.M_buffer_size = conf.wspec.n_gemms * conf.wspec.output_ld_batch;
631627

632-
// Scratchpad: V and M buffers for single-thread execution
633-
// Weight buffer is in persistent resource_t (not scratchpad)
628+
// Scratchpad: U (transformed weights), V (transformed input), M (GEMM output)
634629
using namespace memory_tracking::names;
635630

631+
scratchpad.book<float>(key_wino_U, conf.wspec.weight_matrix_size);
636632
scratchpad.book<float>(key_wino_V, conf.wspec.V_buffer_size);
637633
scratchpad.book<float>(key_wino_M, conf.wspec.M_buffer_size);
638634

@@ -651,24 +647,20 @@ status_t rvv_wino_convolution_fwd_t::execute_forward(
651647
const auto scratchpad = ctx.get_scratchpad_grantor();
652648
const auto *brg_kernel = pd()->brg_kernel_.get();
653649

654-
// Get persistent weight buffer from resource (cached across execute calls)
655-
auto *wino_resource
656-
= ctx.get_resource_mapper()->get<rvv_wino_resource_t>(this);
657-
float *transformed_weights = wino_resource->get_weight_buffer();
658-
659-
// Transform weights on first execute, cache for subsequent calls
660-
if (!wino_resource->weights_valid()) {
661-
compute_filter_transform_3x3_to_4x4_gemm_layout(weights,
662-
transformed_weights, conf.wspec.N, conf.wspec.K,
663-
conf.wspec.weight_ic_rounded, conf.wspec.weight_oc_rounded);
664-
wino_resource->set_weights_valid();
665-
}
666-
667650
using namespace memory_tracking::names;
651+
652+
// Transform weights into scratchpad buffer every execute (like x64 brgemm)
653+
float *transformed_weights = scratchpad.template get<float>(key_wino_U);
654+
compute_filter_transform_3x3_to_4x4_gemm_layout(weights,
655+
transformed_weights, conf.wspec.N, conf.wspec.K,
656+
conf.wspec.weight_ic_rounded, conf.wspec.weight_oc_rounded);
657+
668658
float *V = scratchpad.template get<float>(key_wino_V);
669659
float *M = scratchpad.template get<float>(key_wino_M);
670660

671-
// JIT kernels created in create_resource(), retrieved from resource
661+
// JIT kernels persisted in resource across execute() calls
662+
auto *wino_resource
663+
= ctx.get_resource_mapper()->get<rvv_wino_resource_t>(this);
672664
auto *input_xform = wino_resource->get_input_xform();
673665
auto *output_xform = wino_resource->get_output_xform();
674666

src/cpu/rv64/rvv_winograd_convolution.hpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,26 +98,20 @@ status_t rvv_winograd_init_conf(rvv_winograd_conf_t &conf,
9898
const memory_desc_t &dst_md, const memory_desc_t &bias_md,
9999
const primitive_attr_t &attr);
100100

101-
// Resource for persistent weight transform buffer.
102-
// Weights are transformed on first execute() and cached for reuse.
101+
// Resource for JIT kernel persistence across execute() calls.
103102
struct rvv_wino_resource_t : public resource_t {
104103
rvv_wino_resource_t() = default;
105104

106-
status_t configure(size_t weight_buf_size, const rvv_winograd_conf_t &conf);
105+
status_t configure(const rvv_winograd_conf_t &conf);
107106

108-
float *get_weight_buffer() const { return weight_buf_.get(); }
109107
jit_generator_t *get_input_xform() const { return input_xform_.get(); }
110108
jit_generator_t *get_output_xform() const { return output_xform_.get(); }
111-
bool weights_valid() const { return weights_valid_; }
112-
void set_weights_valid() const { weights_valid_ = true; }
113109

114110
DNNL_DISALLOW_COPY_AND_ASSIGN(rvv_wino_resource_t);
115111

116112
private:
117-
std::unique_ptr<float[]> weight_buf_;
118113
std::unique_ptr<jit_generator_t> input_xform_;
119114
std::unique_ptr<jit_generator_t> output_xform_;
120-
mutable bool weights_valid_ = false;
121115
};
122116

123117
struct rvv_wino_convolution_fwd_t : public primitive_t {
@@ -255,7 +249,7 @@ struct rvv_wino_convolution_fwd_t : public primitive_t {
255249
if (mapper.has_resource(this)) return status::success;
256250
auto r = utils::make_unique<rvv_wino_resource_t>();
257251
if (!r) return status::out_of_memory;
258-
CHECK(r->configure(pd()->conf_.wspec.weight_matrix_size, pd()->conf_));
252+
CHECK(r->configure(pd()->conf_));
259253
mapper.add(this, std::move(r));
260254
return status::success;
261255
}

0 commit comments

Comments
 (0)