
Commit e8ed718

[Bugfix] add missing function params to rocm_aiter_mla.py
Looks like #17864 had an outdated branch, so its [merge commit][1] caused `qo_indptr` and `max_seqlen_qo` to go into the function signature of `aiter_mla_decode_fwd()`, where they're not used, and into the body of `mla_decode_fwd_impl()`, where they aren't defined. This PR fixes the discrepancies and the call sites.

Signed-off-by: David Xia <[email protected]>

[1]: 9f64e93#diff-88fd09f50e8cfc77678ade87483ab9a89ce58904203578f8816882763bd577c2
1 parent 9f64e93 commit e8ed718
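
For context, this is the corrected signature as it lands in vllm/attention/ops/rocm_aiter_mla.py (see the diff below), reduced to a self-contained sketch: the body is a stub standing in for the real `torch.ops.vllm.rocm_aiter_mla_decode_fwd` kernel, and the tensor shapes in the example call are made-up placeholders, not the shapes vLLM actually uses.

from typing import Optional

import torch


def aiter_mla_decode_fwd(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: Optional[torch.Tensor] = None,
    kv_indices: Optional[torch.Tensor] = None,
    kv_last_page_lens: Optional[torch.Tensor] = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
):
    # Stub: the real function dispatches to the ROCm aiter MLA decode kernel,
    # which writes the attention output into `o` in place.
    o.zero_()


# Call pattern matching the fixed call sites: qo_indptr and max_seqlen_qo are
# passed positionally, while sm_scale moves to a keyword argument with a default.
q = torch.randn(4, 16, 576)            # dummy shapes, for illustration only
kv_buffer = torch.randn(8, 16, 1, 576)
o = torch.empty(4, 16, 512)
qo_indptr = torch.arange(5, dtype=torch.int32)
aiter_mla_decode_fwd(q, kv_buffer, o, qo_indptr, 1, sm_scale=0.1)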

File tree: 3 files changed, +19 −5 lines


vllm/attention/backends/rocm_aiter_mla.py

+5-2
@@ -425,11 +425,14 @@ def _forward_decode(
 
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
 
-        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
+        aiter_mla_decode_fwd(q,
+                             kv_buffer,
+                             o,
                              attn_metadata.qo_indptr,
                              attn_metadata.max_query_len,
                              attn_metadata.paged_kv_indptr,
                              attn_metadata.paged_kv_indices,
-                             attn_metadata.paged_kv_last_page_lens)
+                             attn_metadata.paged_kv_last_page_lens,
+                             sm_scale=self.scale)
 
         return self._v_up_proj(o)

vllm/attention/ops/rocm_aiter_mla.py

+7-1
@@ -28,19 +28,21 @@ def aiter_mla_decode_fwd(
     q: torch.Tensor,
     kv_buffer: torch.Tensor,
     o: torch.Tensor,
-    sm_scale: float,
     qo_indptr: torch.Tensor,
     max_seqlen_qo: int,
     kv_indptr: Optional[torch.Tensor] = None,
     kv_indices: Optional[torch.Tensor] = None,
     kv_last_page_lens: Optional[torch.Tensor] = None,
+    sm_scale: float = 1.0,
     logit_cap: float = 0.0,
 ):
 
     torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
                                              kv_buffer.view(
                                                  -1, 1, 1, q.shape[-1]),
                                              o,
+                                             qo_indptr,
+                                             max_seqlen_qo,
                                              kv_indptr,
                                              kv_indices,
                                              kv_last_page_lens,
@@ -52,6 +54,8 @@ def mla_decode_fwd_impl(
     q: torch.Tensor,
     kv_buffer: torch.Tensor,
     o: torch.Tensor,
+    qo_indptr: torch.Tensor,
+    max_seqlen_qo: int,
     kv_indptr: Optional[torch.Tensor] = None,
     kv_indices: Optional[torch.Tensor] = None,
     kv_last_page_lens: Optional[torch.Tensor] = None,
@@ -76,6 +80,8 @@ def mla_decode_fwd_fake(
     q: torch.Tensor,
     kv_buffer: torch.Tensor,
     o: torch.Tensor,
+    qo_indptr: torch.Tensor,
+    max_seqlen_qo: int,
     kv_indptr: Optional[torch.Tensor] = None,
     kv_indices: Optional[torch.Tensor] = None,
     kv_last_page_lens: Optional[torch.Tensor] = None,
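
Why `mla_decode_fwd_fake` changes too: `mla_decode_fwd_impl` and `mla_decode_fwd_fake` are the real and fake (meta) implementations behind the `rocm_aiter_mla_decode_fwd` custom op, so their parameter lists have to mirror each other or tracing breaks once the new arguments are passed. Below is a minimal, generic sketch of that pattern using `torch.library`; vLLM registers its ops through its own helper, so this is illustrative, not the actual registration code, and the op name and shapes are made up.

import torch


# Hypothetical op name, for illustration only.
@torch.library.custom_op("demo::mla_decode_fwd", mutates_args=("o",))
def mla_decode_fwd_impl(q: torch.Tensor, o: torch.Tensor,
                        qo_indptr: torch.Tensor, max_seqlen_qo: int) -> None:
    # Placeholder "real" kernel: writes into `o` in place.
    o.copy_(q)


@mla_decode_fwd_impl.register_fake
def mla_decode_fwd_fake(q: torch.Tensor, o: torch.Tensor,
                        qo_indptr: torch.Tensor, max_seqlen_qo: int) -> None:
    # The fake implementation only describes the op for tracing; it must
    # accept the same parameters as the real one, which is why this commit
    # adds qo_indptr and max_seqlen_qo to both functions.
    pass


# Both the eager and traced paths see the same argument list.
q, o = torch.randn(2, 3), torch.empty(2, 3)
torch.ops.demo.mla_decode_fwd(q, o, torch.tensor([0, 2]), 1)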

vllm/v1/attention/backends/mla/rocm_aiter_mla.py

+7-2
@@ -186,9 +186,14 @@ def _forward_decode(
 
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
 
-        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
+        aiter_mla_decode_fwd(q,
+                             kv_buffer,
+                             o,
+                             attn_metadata.qo_indptr,
+                             attn_metadata.max_query_len,
                              attn_metadata.decode.paged_kv_indptr,
                              attn_metadata.decode.paged_kv_indices,
-                             attn_metadata.decode.paged_kv_last_page_len)
+                             attn_metadata.decode.paged_kv_last_page_len,
+                             sm_scale=self.scale)
 
         return self._v_up_proj(o)
