@@ -106,12 +106,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
106106 i_t = tl .load (num_accepted_tokens + i_n ).to (tl .int64 ) - 1
107107 else :
108108 i_t = 0
109- # Load state index and check for invalid entries
109+ # Load state index and check for PAD_SLOT_ID (-1)
110110 state_idx = tl .load (ssm_state_indices + i_n * stride_indices_seq + i_t ).to (
111111 tl .int64
112112 )
113- # Skip if state index is invalid (NULL_BLOCK_ID=0 )
114- if state_idx <= 0 :
113+ # Skip if state index is invalid (PAD_SLOT_ID = -1 )
114+ if state_idx < 0 :
115115 return
116116 p_h0 = h0 + state_idx * stride_init_state_token
117117 else :
@@ -150,12 +150,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
150150
151151 # keep the states for multi-query tokens
152152 if INPLACE_FINAL_STATE :
153- # Load state index and check for invalid entries
153+ # Load state index and check for PAD_SLOT_ID (-1)
154154 final_state_idx = tl .load (
155155 ssm_state_indices + i_n * stride_indices_seq + i_t
156156 ).to (tl .int64 )
157- # Only store if state index is valid (not NULL_BLOCK_ID=0 )
158- if final_state_idx > 0 :
157+ # Only store if state index is valid (not PAD_SLOT_ID )
158+ if final_state_idx >= 0 :
159159 p_ht = ht + final_state_idx * stride_final_state_token
160160 p_ht = p_ht + i_hv * V * K + o_v [:, None ] * K + o_k [None , :]
161161 tl .store (p_ht , b_h .to (p_ht .dtype .element_ty ), mask = mask_h )
@@ -292,8 +292,7 @@ def fused_recurrent_gated_delta_rule_packed_decode_kernel(
292292 state_idx = tl .load (ssm_state_indices + i_n * stride_indices_seq ).to (tl .int64 )
293293 p_o = o + (i_n * HV + i_hv ) * V + o_v
294294
295- # Skip if state index is invalid (NULL_BLOCK_ID=0)
296- if state_idx <= 0 :
295+ if state_idx < 0 :
297296 zero = tl .zeros ([BV ], dtype = tl .float32 ).to (p_o .dtype .element_ty )
298297 tl .store (p_o , zero , mask = mask_v )
299298 return
0 commit comments