Skip to content

Commit 93956d6

Browse files
committed
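fix bug: NULL_BLOCK_ID — treat only negative state indices (PAD_SLOT_ID = -1) as invalid instead of 0, and set NULL_BLOCK_ID to -1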
1 parent 0e31de1 commit 93956d6

3 files changed

Lines changed: 14 additions & 15 deletions

File tree

vllm/model_executor/layers/fla/ops/fused_recurrent.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -106,12 +106,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
106106
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
107107
else:
108108
i_t = 0
109-
# Load state index and check for invalid entries
109+
# Load state index and check for PAD_SLOT_ID (-1)
110110
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
111111
tl.int64
112112
)
113-
# Skip if state index is invalid (NULL_BLOCK_ID=0)
114-
if state_idx <= 0:
113+
# Skip if state index is invalid (PAD_SLOT_ID = -1)
114+
if state_idx < 0:
115115
return
116116
p_h0 = h0 + state_idx * stride_init_state_token
117117
else:
@@ -150,12 +150,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
150150

151151
# keep the states for multi-query tokens
152152
if INPLACE_FINAL_STATE:
153-
# Load state index and check for invalid entries
153+
# Load state index and check for PAD_SLOT_ID (-1)
154154
final_state_idx = tl.load(
155155
ssm_state_indices + i_n * stride_indices_seq + i_t
156156
).to(tl.int64)
157-
# Only store if state index is valid (not NULL_BLOCK_ID=0)
158-
if final_state_idx > 0:
157+
# Only store if state index is valid (not PAD_SLOT_ID)
158+
if final_state_idx >= 0:
159159
p_ht = ht + final_state_idx * stride_final_state_token
160160
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
161161
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
@@ -292,8 +292,7 @@ def fused_recurrent_gated_delta_rule_packed_decode_kernel(
292292
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
293293
p_o = o + (i_n * HV + i_hv) * V + o_v
294294

295-
# Skip if state index is invalid (NULL_BLOCK_ID=0)
296-
if state_idx <= 0:
295+
if state_idx < 0:
297296
zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
298297
tl.store(p_o, zero, mask=mask_v)
299298
return

vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -106,12 +106,12 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
106106
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
107107
else:
108108
i_t = 0
109-
# Load state index and check for invalid entries
109+
# Load state index and check for PAD_SLOT_ID (-1)
110110
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
111111
tl.int64
112112
)
113-
# Skip if state index is invalid (NULL_BLOCK_ID=0)
114-
if state_idx <= 0:
113+
# Skip if state index is invalid (PAD_SLOT_ID = -1)
114+
if state_idx < 0:
115115
return
116116
p_h0 = h0 + state_idx * stride_init_state_token
117117
else:
@@ -155,12 +155,12 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
155155

156156
# keep the states for multi-query tokens
157157
if INPLACE_FINAL_STATE:
158-
# Load state index and check for invalid entries
158+
# Load state index and check for PAD_SLOT_ID (-1)
159159
final_state_idx = tl.load(
160160
ssm_state_indices + i_n * stride_indices_seq + i_t
161161
).to(tl.int64)
162-
# Only store if state index is valid (not NULL_BLOCK_ID=0)
163-
if final_state_idx > 0:
162+
# Only store if state index is valid (not PAD_SLOT_ID)
163+
if final_state_idx >= 0:
164164
p_ht = ht + final_state_idx * stride_final_state_token
165165
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
166166
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

vllm/v1/attention/backends/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@
4242
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None
4343

4444
PAD_SLOT_ID = -1
45-
NULL_BLOCK_ID = 0
45+
NULL_BLOCK_ID = -1
4646

4747

4848
def is_valid_kv_cache_layout(value: str) -> bool:

0 commit comments

Comments
 (0)