Skip to content

Commit 3a9c1a3

Browse files
loloto and xlucienluo authored
feat: add Stem sparse attention module (#286)
Co-authored-by: xlucienluo <[email protected]>
1 parent 90209c2 commit 3a9c1a3

20 files changed

Lines changed: 2169 additions & 5 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ A more accessible, comprehensive, and efficient toolkit for large model compress
9494
<li>
9595
<strong>Sparse Attention</strong>
9696
<ul style="padding-left: 1.5rem">
97-
<li>Under Development</li>
97+
<li><a href="https://angelslim.readthedocs.io/zh-cn/latest/features/sparse_attention/stem.html">Stem</a></li>
9898
</ul>
9999
</li>
100100
</ul>

README_cn.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@
9595
<li>
9696
<strong>稀疏注意力</strong>
9797
<ul style="padding-left: 1.5rem">
98-
<li>Minference(建设中)</li>
98+
<li><a href="https://angelslim.readthedocs.io/zh-cn/latest/features/sparse_attention/stem.html">Stem</a></li>
9999
</ul>
100100
</li>
101101
</ul>

angelslim/compressor/sparsity/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from .stem import StemInference # noqa: F401
16+
17+
__all__ = ["StemInference"]
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
"""Stem — Sparse Token Estimation Module for long-context LLM inference.
17+
18+
Public API:
19+
StemInference: Callable that patches a HuggingFace model to use Stem
20+
sparse attention during the prefill stage.
21+
"""
22+
23+
from .stem import StemInference
24+
25+
__all__ = ["StemInference"]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
"""Stem backend implementations (torch / HPC)."""
17+
18+
from .dispatcher import stem_forward
19+
20+
__all__ = ["stem_forward"]
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2025 Tencent Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
"""Backend dispatcher: routes Stem prefill to the correct implementation."""
17+
18+
from __future__ import annotations
19+
20+
import torch
21+
22+
from .torch_impl import stem_forward_torch
23+
24+
25+
def stem_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    prefill_kwargs: dict,
) -> torch.Tensor:
    """Dispatch a Stem prefill call to the appropriate backend.

    The backend name is read from
    ``prefill_kwargs["attn_forward_config"]["backend"]`` and defaults to
    ``"torch"`` when that entry is absent. All four arguments are forwarded
    unchanged to the selected implementation.

    Args:
        query_states: Query tensor of shape ``(B, H_q, L_q, D)``.
        key_states: Key tensor of shape ``(B, H_kv, L_kv, D)``.
        value_states: Value tensor of shape ``(B, H_kv, L_kv, D)``.
        prefill_kwargs: Must contain ``"attn_forward_config"``, a mapping
            whose optional ``"backend"`` entry is ``"torch"`` or ``"hpc"``.
            ``"layer_idx"`` is not read by the dispatcher itself; it is
            presumably consumed by the backend implementations (confirm
            against ``torch_impl`` / ``hpc_impl``).

    Returns:
        Whatever the selected backend returns; expected to be the attention
        output tensor of shape ``(B, H_q, L_q, D)``.

    Raises:
        KeyError: If ``prefill_kwargs`` has no ``"attn_forward_config"`` key.
        ValueError: If the requested backend is not ``"torch"`` or ``"hpc"``.
    """
    config = prefill_kwargs["attn_forward_config"]
    backend = config.get("backend", "torch")

    if backend == "torch":
        return stem_forward_torch(query_states, key_states, value_states, prefill_kwargs)

    if backend == "hpc":
        # Lazy import to avoid a hard dependency on the ``hpc`` C++ extension
        # when only the pure-torch path is needed.
        from .hpc_impl import stem_forward_hpc

        return stem_forward_hpc(query_states, key_states, value_states, prefill_kwargs)

    raise ValueError(f"Unknown stem backend: {backend!r}")

0 commit comments

Comments
 (0)