[WIP][Feature] mooncake connector support GQA transport #2947
Conversation
Signed-off-by: zzy-ContiLearn <[email protected]>
…g errors caused by .npu()
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces support for GQA transport in the mooncake connector, which involves significant changes to handle chunked KV cache transfers. The changes include modifications to the sending and receiving threads, metadata handling, and the logic for selecting tensor-parallel ranks for data transfer. While the overall direction looks correct for enabling more flexible data transport, I found a critical issue in the KV cache transfer logic: the source and destination addresses appear to be swapped, and the remote address calculation for chunked transfers seems incorrect. This could lead to incorrect data transfer or corruption, so addressing it is crucial for the feature to work correctly.
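For orientation before the diff: a minimal sketch of the chunked-transfer arithmetic under discussion, using illustrative values. `block_len`, `num_need_pulls`, and `offset` are the names from the diff below; `chunk_range` and all numeric values are hypothetical.

```python
# Illustrative sketch, not the PR's code: a KV block of block_len bytes is
# split into num_need_pulls equal chunks, and each pull copies the chunk
# selected by `offset`.
block_len = 4096          # bytes per KV block (example value)
num_need_pulls = 4        # e.g. prefill-to-decode TP ratio for GQA heads
inner_block_len = block_len // num_need_pulls


def chunk_range(base_addr: int, block_id: int, offset: int) -> tuple[int, int]:
    """Start address and length of chunk `offset` inside block `block_id`."""
    start = base_addr + block_id * block_len + offset * inner_block_len
    return start, inner_block_len


# The num_need_pulls chunks exactly tile one block:
base = 0x10000
starts = [chunk_range(base, block_id=2, offset=o)[0]
          for o in range(num_need_pulls)]
assert starts[0] == base + 2 * block_len
assert all(b - a == inner_block_len for a, b in zip(starts, starts[1:]))
```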
```diff
 src_list, dst_list, length_list = [], [], []
 for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
         zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
-    block_len = (self.block_len[k % 2]
-                 if self.use_mla else self.block_len[0])
-    for i, remote_block_id in enumerate(grouped_remote_block_ids):
-        local_block_ids = grouped_local_block_ids[i]
-        src = src_layer_base_addr + local_block_ids[0] * block_len
-        dst = dst_layer_base_addr + remote_block_id[0] * block_len
-        length = len(local_block_ids) * block_len
+    block_len = self.block_len[k % 2]
+    inner_block_len = block_len // self.num_need_pulls
+    for remote_block_id, local_block_id in zip(grouped_remote_block_ids,
+                                               grouped_local_block_ids):
+        src = src_layer_base_addr + local_block_id[0] * block_len + offset * inner_block_len
+        dst = dst_layer_base_addr + remote_block_id[0] * inner_block_len
+        length = inner_block_len * len(local_block_id)
         src_list.append(src)
         dst_list.append(dst)
         length_list.append(length)
-ret = self.engine.batch_transfer_sync_read(session_id, src_list,
-                                           dst_list, length_list)
+ret = self.engine.batch_transfer_sync_read(session_id, src_list, dst_list,
+                                           length_list)
```
There appears to be a critical issue with the `batch_transfer_sync_read` call and the address calculations. The `KVCacheRecvingThread` should be reading from a remote source to a local destination, but the source and destination addresses seem to be swapped.

- **Swapped source/destination:** `src_list` is populated with local addresses and `dst_list` with remote addresses. Since `batch_transfer_sync_read` is a read operation, the source should be remote and the destination local; the logic for calculating `src` and `dst` appears to be swapped.
- **Incorrect remote address calculation:** the remote address (assigned to `dst` in the current code) does not use the `offset` parameter, which will cause multi-chunk transfers to read the same chunk from the remote source repeatedly. It should likely incorporate both `block_len` and `offset`.

Here is a suggested correction that assumes `batch_transfer_sync_read` takes `(session, remote_src_addrs, local_dst_addrs, lengths)` and fixes the address calculations. The confusing variable names (`src_layer_base_addr` for local, `dst_layer_base_addr` for remote) are renamed for clarity within the suggestion:
```python
src_list, dst_list, length_list = [], [], []
for k, (local_layer_base_addr, remote_layer_base_addr) in enumerate(
        zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
    block_len = self.block_len[k % 2]
    inner_block_len = block_len // self.num_need_pulls
    for remote_block_id, local_block_id in zip(grouped_remote_block_ids,
                                               grouped_local_block_ids):
        # remote source address
        src = remote_layer_base_addr + remote_block_id[0] * block_len + offset * inner_block_len
        # local destination address
        dst = local_layer_base_addr + local_block_id[0] * block_len + offset * inner_block_len
        length = inner_block_len * len(local_block_id)
        src_list.append(src)
        dst_list.append(dst)
        length_list.append(length)
ret = self.engine.batch_transfer_sync_read(session_id, src_list, dst_list,
                                           length_list)
```
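As a quick sanity check of the second point above (all numeric values are made up for illustration): the current remote-address formula yields the same address for every `offset`, while the corrected one steps through distinct chunks of the remote block.

```python
# Illustrative values only; demonstrates the flagged bug vs. the suggestion.
block_len, num_need_pulls = 1024, 4
inner_block_len = block_len // num_need_pulls  # 256

remote_base, remote_block_id = 0x2000, 3

# Current code: `offset` is never applied to the remote address.
buggy = [remote_base + remote_block_id * inner_block_len
         for offset in range(num_need_pulls)]
# Suggested fix: stride by block_len, then step by offset * inner_block_len.
fixed = [remote_base + remote_block_id * block_len + offset * inner_block_len
         for offset in range(num_need_pulls)]

assert len(set(buggy)) == 1               # every pull reads the same chunk
assert len(set(fixed)) == num_need_pulls  # each pull reads a distinct slice
```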
Signed-off-by: zzy-ContiLearn <[email protected]>
This pull request has conflicts; please resolve them before we can evaluate the pull request.
Force-pushed from 6c0d05d to f002afb
…ache eviction in the Prefill node. Signed-off-by: chenxiao <[email protected]>
Force-pushed from f002afb to 82db3b5
What this PR does / why we need it?
TODO:
Does this PR introduce any user-facing change?
How was this patch tested?