|
| 1 | +# Copyright 2025 Tencent Inc. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | + |
| 16 | +"""Backend dispatcher: routes Stem prefill to the correct implementation.""" |
| 17 | + |
| 18 | +from __future__ import annotations |
| 19 | + |
| 20 | +import torch |
| 21 | + |
| 22 | +from .torch_impl import stem_forward_torch |
| 23 | + |
| 24 | + |
| 25 | +def stem_forward( |
| 26 | + query_states: torch.Tensor, |
| 27 | + key_states: torch.Tensor, |
| 28 | + value_states: torch.Tensor, |
| 29 | + prefill_kwargs: dict, |
| 30 | +) -> torch.Tensor: |
| 31 | + """Dispatch a Stem prefill call to the appropriate backend. |
| 32 | +
|
| 33 | + Args: |
| 34 | + query_states: Query tensor of shape ``(B, H_q, L_q, D)``. |
| 35 | + key_states: Key tensor of shape ``(B, H_kv, L_kv, D)``. |
| 36 | + value_states: Value tensor of shape ``(B, H_kv, L_kv, D)``. |
| 37 | + prefill_kwargs: Must contain ``"attn_forward_config"`` (with a |
| 38 | + ``"backend"`` key) and ``"layer_idx"``. |
| 39 | +
|
| 40 | + Returns: |
| 41 | + Attention output tensor of shape ``(B, H_q, L_q, D)``. |
| 42 | +
|
| 43 | + Raises: |
| 44 | + ValueError: If the requested backend is not ``"torch"`` or ``"hpc"``. |
| 45 | + """ |
| 46 | + config = prefill_kwargs["attn_forward_config"] |
| 47 | + backend = config.get("backend", "torch") |
| 48 | + |
| 49 | + if backend == "torch": |
| 50 | + return stem_forward_torch(query_states, key_states, value_states, prefill_kwargs) |
| 51 | + |
| 52 | + if backend == "hpc": |
| 53 | + # Lazy import to avoid hard dependency on the ``hpc`` C++ extension |
| 54 | + # when only the pure-torch path is needed. |
| 55 | + from .hpc_impl import stem_forward_hpc |
| 56 | + |
| 57 | + return stem_forward_hpc(query_states, key_states, value_states, prefill_kwargs) |
| 58 | + |
| 59 | + raise ValueError(f"Unknown stem backend: {backend!r}") |
0 commit comments