[Bugfix] add missing function params to rocm_aiter_mla.py #17911

Draft: wants to merge 2 commits into base: main
Changes from all commits
12 changes: 6 additions & 6 deletions vllm/attention/backends/mla/common.py
@@ -1213,9 +1213,9 @@ def _compute_prefill_context(

attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
q,
k,
v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
@@ -1267,9 +1267,9 @@ def _forward_prefill(
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
q,
k,
v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
56 changes: 42 additions & 14 deletions vllm/attention/backends/rocm_aiter_mla.py
@@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]:

@dataclass
class AiterMLAMetadata(MLACommonMetadata):
# The following 4 tensors are for current version of AITER MLA
# The following 5 tensors are for current version of AITER MLA
block_table_bound: Optional[torch.Tensor] = None
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None
@@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata):
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_lens: Optional[torch.Tensor] = None

# This is just to make new AITER MLA API work
# -- MTP support is not added yet.
qo_indptr: Optional[torch.Tensor] = None

@property
def prefill_metadata(self):
prefill_metadata = super().prefill_metadata
@@ -74,6 +78,7 @@ def prefill_metadata(self):
prefill_metadata\
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
prefill_metadata.block_table_bound = self.block_table_bound
prefill_metadata.qo_indptr = self.qo_indptr

# update the cache
self._cached_prefill_metadata = self.__class__(
@@ -93,6 +98,7 @@ def decode_metadata(self):
decode_metadata\
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
decode_metadata.block_table_bound = self.block_table_bound
decode_metadata.qo_indptr = self.qo_indptr

# update the cache
self._cached_decode_metadata = self.__class__(
@@ -136,6 +142,7 @@ def prepare(self):
self.paged_kv_indptr: list[int] = [0]
self.paged_kv_last_page_lens: list[int] = []
self.total_blocks = 0
self.qo_indptr: list[int] = [0]

def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
prefix_cache_hit: bool):
@@ -208,6 +215,7 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
self.paged_kv_indices.extend(block_table[:block_table_bound])
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
block_table_bound)
self.qo_indptr.append(self.qo_indptr[-1] + 1)

last_page_len = seq_len % self.block_size
if last_page_len == 0:
@@ -226,6 +234,8 @@ def build(self, seq_lens: list[int], query_lens: list[int],
self.paged_kv_indptr.extend([last_paged_kv_indptr] *
cuda_graph_pad_size)
self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
last_qo_indptr = self.qo_indptr[-1]
self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)

# For current version of AITER MLA
if len(self.paged_kv_indptr) > 0:
@@ -245,16 +255,22 @@
1,
device=device,
dtype=torch.int)

qo_indptr = torch.tensor(self.qo_indptr,
device=device,
dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_lens_tensor = None
block_table_bound_tensor = None
qo_indptr = None

metadata.paged_kv_indptr = paged_kv_indptr_tensor
metadata.paged_kv_indices = paged_kv_indices_tensor
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
metadata.block_table_bound = block_table_bound_tensor
metadata.qo_indptr = qo_indptr

return metadata

@@ -263,21 +279,25 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):

@contextmanager
def graph_capture(self, max_batch_size: int):
kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
max_batch_size=max_batch_size,
block_size=self.runner.block_size,
max_block_per_batch=self.runner.get_max_block_per_batch(),
device=self.runner.device)
kv_indices, kv_indptr, last_page_lens, qo_indptr = \
get_aiter_mla_metadata(
max_batch_size=max_batch_size,
block_size=self.runner.block_size,
max_block_per_batch=\
self.runner.get_max_block_per_batch(),
device=self.runner.device)
self._paged_kv_indices_tensor = kv_indices
self._paged_kv_indptr_tensor = kv_indptr
self._paged_kv_last_page_lens_tensor = last_page_lens
self._qo_indptr_tensor = qo_indptr

with super().graph_capture(max_batch_size):
yield

del self._paged_kv_indices_tensor
del self._paged_kv_indptr_tensor
del self._paged_kv_last_page_lens_tensor
del self._qo_indptr_tensor

def graph_capture_get_metadata_for_batch(
self,
@@ -291,10 +311,12 @@
paged_kv_indices = self._paged_kv_indices_tensor
paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
batch_size]
qo_indptr = self._qo_indptr_tensor[:batch_size + 1]

metadata.paged_kv_indptr = paged_kv_indptr
metadata.paged_kv_indices = paged_kv_indices
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
metadata.qo_indptr = qo_indptr

return metadata

@@ -311,6 +333,7 @@ def get_graph_input_buffers(self,
input_buffers[
"paged_kv_last_page_lens"] = attn_metadata.\
decode_metadata.paged_kv_last_page_lens
input_buffers['qo_indptr'] = attn_metadata.qo_indptr

return input_buffers

@@ -330,6 +353,8 @@ def prepare_graph_input_buffers(self,
input_buffers["paged_kv_last_page_lens"].copy_(
attn_metadata.decode_metadata.paged_kv_last_page_lens,
non_blocking=True)
input_buffers["qo_indptr"].copy_(
attn_metadata.decode_metadata.qo_indptr, non_blocking=True)


class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
@@ -370,11 +395,9 @@ def _flash_attn_varlen_diff_headdims(
softmax_scale: float, return_softmax_lse: bool,
**kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
q,
k,
v,
**kwargs,
)

@@ -394,17 +417,22 @@ def _forward_decode(
B = q_nope.shape[0]

q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
o = torch.empty(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)

kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
aiter_mla_decode_fwd(q,
kv_buffer,
o,
attn_metadata.qo_indptr,
attn_metadata.max_query_len,
Comment on lines +431 to +432 (Contributor Author):
There's a type check warning because these can be None, but the function doesn't allow None. Do we add more asserts above to ensure they're not None?

attn_metadata.paged_kv_indptr,
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_lens)
attn_metadata.paged_kv_last_page_lens,
sm_scale=self.scale)

return self._v_up_proj(o)
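
On the review question above (whether to assert these fields are not None before the call): a minimal sketch of what such narrowing could look like inside AiterMLAImpl._forward_decode. The exact set of asserts and their placement is an assumption, not part of this diff.

    # Hypothetical guard before the aiter_mla_decode_fwd call: the asserts narrow
    # the Optional[...] metadata fields for the type checker and fail fast if the
    # decode metadata was built without the paged-KV tensors.
    assert attn_metadata.qo_indptr is not None
    assert attn_metadata.max_query_len is not None
    assert attn_metadata.paged_kv_indptr is not None
    assert attn_metadata.paged_kv_indices is not None
    assert attn_metadata.paged_kv_last_page_lens is not None

    aiter_mla_decode_fwd(q,
                         kv_buffer,
                         o,
                         attn_metadata.qo_indptr,
                         attn_metadata.max_query_len,
                         attn_metadata.paged_kv_indptr,
                         attn_metadata.paged_kv_indices,
                         attn_metadata.paged_kv_last_page_lens,
                         sm_scale=self.scale)
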
15 changes: 13 additions & 2 deletions vllm/attention/ops/rocm_aiter_mla.py
@@ -20,24 +20,29 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
paged_kv_last_page_lens = torch.full((max_batch_size, ),
block_size,
dtype=torch.int32)
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens
qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
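
For context on what these tensors encode at runtime (this snippet is illustrative only and not part of the diff; the values are made up): the paged-KV tensors form a CSR-style index over the block table, and for plain decode qo_indptr advances by one query token per sequence, matching the builder's qo_indptr.append(self.qo_indptr[-1] + 1) above.

    import torch

    # Toy 2-sequence batch with block_size = 16.
    # Sequence 0 occupies physical blocks [3, 7] and holds 20 tokens;
    # sequence 1 occupies block [5] and holds 9 tokens.
    paged_kv_indices = torch.tensor([3, 7, 5], dtype=torch.int)      # block ids, concatenated per sequence
    paged_kv_indptr = torch.tensor([0, 2, 3], dtype=torch.int)       # seq i owns indices[indptr[i]:indptr[i+1]]
    paged_kv_last_page_lens = torch.tensor([4, 9], dtype=torch.int)  # tokens used in each sequence's last block
    qo_indptr = torch.tensor([0, 1, 2], dtype=torch.int)             # one decode query token per sequence
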


def aiter_mla_decode_fwd(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
sm_scale: float,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
sm_scale: float = 1.0,
Comment (Contributor Author):
moved down and added default value for consistency

logit_cap: float = 0.0,
):

torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
kv_buffer.view(
-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
max_seqlen_qo,
kv_indptr,
kv_indices,
kv_last_page_lens,
@@ -49,6 +54,8 @@ def mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
@@ -60,9 +67,11 @@
mla_decode_fwd(q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap)

@@ -71,6 +80,8 @@ def mla_decode_fwd_fake(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
9 changes: 5 additions & 4 deletions vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -123,10 +123,11 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(

fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
sorted_weight_buf, sorted_expert_ids,
num_valid_ids, topk, w1_scale.view(local_E, -1),
w2_scale.view(local_E, -1),
a1_scale.t().contiguous(), *block_shape,
smooth_scale)
num_valid_ids, topk,
a1_scale.t().contiguous(),
w1_scale.view(local_E, -1),
w2_scale.view(local_E,
-1), *block_shape, smooth_scale)

return out_asm

9 changes: 7 additions & 2 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -186,9 +186,14 @@

kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
aiter_mla_decode_fwd(q,
kv_buffer,
o,
attn_metadata.qo_indptr,

Check failure on line 192 in vllm/v1/attention/backends/mla/rocm_aiter_mla.py (GitHub Actions / pre-commit): "AiterMLAMetadata" has no attribute "qo_indptr" [attr-defined]
attn_metadata.max_query_len,

Check failure on line 193 in vllm/v1/attention/backends/mla/rocm_aiter_mla.py (GitHub Actions / pre-commit): "AiterMLAMetadata" has no attribute "max_query_len" [attr-defined]
Comment on lines +192 to +193 (Contributor Author):
How should we create these attributes?

attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len)
attn_metadata.decode.paged_kv_last_page_len,
sm_scale=self.scale)

return self._v_up_proj(o)
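
On the open question above and the attr-defined pre-commit failures: one possible direction (purely a sketch under assumed names, not something this PR currently contains) is to carry the new fields on the v1 decode metadata object that already holds the paged-KV tensors, and read them through attn_metadata.decode at this call site.

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class AiterMLADecodeMetadata:
        # Assumed decode sub-metadata; field names mirror the existing
        # attn_metadata.decode.* accesses in this file.
        paged_kv_indptr: Optional[torch.Tensor] = None
        paged_kv_indices: Optional[torch.Tensor] = None
        paged_kv_last_page_len: Optional[torch.Tensor] = None
        # New fields required by the updated aiter_mla_decode_fwd signature:
        qo_indptr: Optional[torch.Tensor] = None
        max_query_len: Optional[int] = None

The call above would then pass attn_metadata.decode.qo_indptr and attn_metadata.decode.max_query_len, which would also resolve the attr-defined errors flagged on lines 192 and 193.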