Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions picolm/model.c
Original file line number Diff line number Diff line change
Expand Up @@ -601,17 +601,29 @@ float *model_forward(model_t *m, int token, int pos) {
rope(s->q, k_tmp, head_dim, n_heads, n_kv_heads, cos_pos, sin_pos);

/* Convert K to FP16 and store */
#ifdef PICOLM_FP16_HW
for (int d = 0; d < kv_dim; d += 4) {
f32x4_to_fp16(key_pos_fp16 + d, vld1q_f32(k_tmp + d));
}
#else
for (int d = 0; d < kv_dim; d++) {
key_pos_fp16[d] = fp32_to_fp16(k_tmp[d]);
}
#endif

/* V projection -> store directly as FP16 */
float *v_tmp = s->xb2;
matmul(v_tmp, s->xb, lw->attn_v, dim, kv_dim, lw->type_attn_v);
uint16_t *val_pos_fp16 = vcache_layer + (size_t)pos * kv_dim;
#ifdef PICOLM_FP16_HW
for (int d = 0; d < kv_dim; d += 4) {
f32x4_to_fp16(val_pos_fp16 + d, vld1q_f32(v_tmp + d));
}
#else
for (int d = 0; d < kv_dim; d++) {
val_pos_fp16[d] = fp32_to_fp16(v_tmp[d]);
}
#endif

/* ---- Flash Attention (online softmax) ----
*
Expand Down Expand Up @@ -647,10 +659,18 @@ float *model_forward(model_t *m, int token, int pos) {
for (int t = 0; t <= pos; t++) {
/* Compute score: dot(Q_h, K_t) / sqrt(head_dim) */
const uint16_t *kt = kcache_layer + (size_t)t * kv_dim + kv_h * head_dim;
#ifdef PICOLM_FP16_HW
float32x4_t dot_acc = vdupq_n_f32(0);
for (int d = 0; d < head_dim; d += 4) {
dot_acc = vmlaq_f32(dot_acc, vld1q_f32(qh + d), fp16x4_to_f32(kt + d));
}
float score = vaddvq_f32(dot_acc);
#else
float score = 0.0f;
for (int d = 0; d < head_dim; d++) {
score += qh[d] * fp16_to_fp32(kt[d]);
}
#endif
score /= sqrtf((float)head_dim);

/* Online softmax update */
Expand All @@ -659,24 +679,47 @@ float *model_forward(model_t *m, int token, int pos) {
if (score > max_score) {
float correction = expf(max_score - score);
sum_exp = sum_exp * correction + 1.0f;
#ifdef PICOLM_FP16_HW
float32x4_t corr_v = vdupq_n_f32(correction);
for (int d = 0; d < head_dim; d += 4) {
float32x4_t a = vld1q_f32(acc + d);
vst1q_f32(acc + d, vmlaq_f32(fp16x4_to_f32(vt + d), a, corr_v));
}
#else
for (int d = 0; d < head_dim; d++) {
acc[d] = acc[d] * correction + fp16_to_fp32(vt[d]);
}
#endif
max_score = score;
} else {
float w = expf(score - max_score);
sum_exp += w;
#ifdef PICOLM_FP16_HW
float32x4_t w_v = vdupq_n_f32(w);
for (int d = 0; d < head_dim; d += 4) {
float32x4_t a = vld1q_f32(acc + d);
vst1q_f32(acc + d, vmlaq_f32(a, fp16x4_to_f32(vt + d), w_v));
}
#else
for (int d = 0; d < head_dim; d++) {
acc[d] += w * fp16_to_fp32(vt[d]);
}
#endif
}
}

/* Normalize */
float inv_sum = 1.0f / sum_exp;
#ifdef PICOLM_FP16_HW
float32x4_t inv_v = vdupq_n_f32(inv_sum);
for (int d = 0; d < head_dim; d += 4) {
vst1q_f32(xbh + d, vmulq_f32(vld1q_f32(acc + d), inv_v));
}
#else
for (int d = 0; d < head_dim; d++) {
xbh[d] = acc[d] * inv_sum;
}
#endif
}

/* Output projection */
Expand Down
9 changes: 9 additions & 0 deletions picolm/quant.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ static inline float vaddvq_f32_compat(float32x4_t v) {
return vget_lane_f32(vpadd_f32(r, r), 0);
#endif
}
#if defined(__aarch64__)
/* AArch64 baseline (ARMv8-A) guarantees hardware FP16<->FP32 vector
 * conversion, so advertise the fast path that model.c selects with
 * #ifdef PICOLM_FP16_HW.  NOTE(review): the callers' 4-wide loops assume
 * head_dim/kv_dim are multiples of 4 — confirm at model load time. */
#define PICOLM_FP16_HW 1
/* Widen 4 IEEE-754 half-precision values (stored as raw uint16_t bit
 * patterns, e.g. a KV-cache row) to a float32x4_t.  p must point to at
 * least 4 readable elements. */
static inline float32x4_t fp16x4_to_f32(const uint16_t *p) {
return vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(p)));
}
/* Narrow 4 floats to IEEE-754 half precision and store them as raw
 * uint16_t at p (rounds per the current FP rounding mode, which is
 * round-to-nearest-even by default). */
static inline void f32x4_to_fp16(uint16_t *p, float32x4_t v) {
vst1_u16(p, vreinterpret_u16_f16(vcvt_f16_f32(v)));
}
#endif
#endif

#if defined(__SSE2__) || (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_AMD64)))
Expand Down
19 changes: 19 additions & 0 deletions picolm/tensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,31 @@ void rope(float *q, float *k, int head_dim, int n_heads, int n_kv_heads,
/* Apply RoPE to all KV heads */
for (int h = 0; h < n_kv_heads; h++) {
float *kh = k + h * head_dim;
#ifdef PICOLM_NEON
int i = 0;
for (; i + 3 < half; i += 4) {
float32x4x2_t kv = vld2q_f32(kh + i * 2);
float32x4_t cv = vld1q_f32(cos_pos + i);
float32x4_t sv = vld1q_f32(sin_pos + i);
float32x4_t new_even = vmlsq_f32(vmulq_f32(kv.val[0], cv), kv.val[1], sv);
float32x4_t new_odd = vmlaq_f32(vmulq_f32(kv.val[0], sv), kv.val[1], cv);
float32x4x2_t result = {{ new_even, new_odd }};
vst2q_f32(kh + i * 2, result);
}
for (; i < half; i++) {
float k0 = kh[i * 2];
float k1 = kh[i * 2 + 1];
kh[i * 2] = k0 * cos_pos[i] - k1 * sin_pos[i];
kh[i * 2 + 1] = k0 * sin_pos[i] + k1 * cos_pos[i];
}
#else
for (int i = 0; i < half; i++) {
float k0 = kh[i * 2];
float k1 = kh[i * 2 + 1];
kh[i * 2] = k0 * cos_pos[i] - k1 * sin_pos[i];
kh[i * 2 + 1] = k0 * sin_pos[i] + k1 * cos_pos[i];
}
#endif
}
}

Expand Down