Conversation

UNIDY2002
Contributor

In this PR, we propose Mooncake EP and the Mooncake Backend.

Mooncake EP is an adaptation of DeepEP that supports fault tolerance for large-scale MoE inference. It remains API-compatible with DeepEP, with an extra broken_ranks tensor to track failed ranks.
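
For illustration, here is a minimal sketch of how a DeepEP-style low-latency dispatch call looks with the extra broken_ranks tensor. The buffer object, placeholder variables, tensor dtype/shape, and argument values below are assumptions made for the sketch; the actual API is documented in doc/en/ep-backend.md.

```python
import torch

# Sketch only: `ep_buffer`, `world_size`, `x`, and `topk_idx` are placeholders,
# and the dtype/shape chosen for `broken_ranks` is an assumption, not the documented API.
broken_ranks = torch.zeros(world_size, dtype=torch.int32, device="cuda")

# Same call shape as DeepEP's low-latency dispatch, plus the `broken_ranks` tensor.
packed_recv_x, packed_recv_count, handle, event, hook = ep_buffer.dispatch(
    x, topk_idx, broken_ranks,
    num_max_dispatch_tokens_per_rank=128,
    num_experts=256,
    timeout_us=10_000,
)

# Peers that fail or time out are flagged in `broken_ranks`, so the caller can
# drop or reroute their traffic instead of hanging the whole job.
```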

Mooncake Backend is a PyTorch distributed backend designed as a fault-tolerant replacement for NCCL and Gloo. It can continue to perform collective communication under rank failures and reports those failures to upper layers for graceful handling.
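
A minimal usage sketch, assuming the backend is registered under the name "mooncake" and selected through the standard torch.distributed entry point (both the backend string and the import-time registration are assumptions here; see doc/en/ep-backend.md for the actual setup):

```python
import os
import torch
import torch.distributed as dist
import mooncake  # assumption: importing the package registers the backend with PyTorch

# Standard torch.distributed initialization; only the backend name differs.
dist.init_process_group(
    backend="mooncake",  # assumed backend string
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)

t = torch.ones(1024)
dist.all_reduce(t)  # per the PR description, this keeps working under rank failures
                    # and reports the failed ranks to the caller for graceful handling
```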

Read more at doc/en/ep-backend.md.


Tests

Since the C++ APIs are not intended for direct use, no C++ unit tests are provided. Instead, three Python unit tests are included under mooncake-wheel/tests/:

  • test_mooncake_ep.py: Adapted from DeepEP’s test_low_latency.py. Verifies the correctness of the EP APIs and includes a basic performance test.
  • test_mooncake_backend.py: Validates the correctness of the Mooncake Backend.
  • test_mooncake_backend_perf.py: Compares the performance of the Mooncake Backend against NCCL and Gloo.

Performance

Tested on a single node with 8× H100 GPUs.

Mooncake EP (pure RDMA)

| Impl | Dispatch bandwidth | Dispatch latency | Combine bandwidth | Combine latency |
| --- | --- | --- | --- | --- |
| Mooncake | 41 GB/s | 184 us | 38 GB/s | 387 us |
| DeepEP | 46 GB/s | 163 us | 46 GB/s | 318 us |

Mooncake Backend

Here are the preliminary performance results for the Mooncake Backend; further optimizations are planned.

All data are in microseconds.

Mooncake vs. Gloo

Allgather

| Data Size | Mooncake | Gloo |
| --- | --- | --- |
| 1K | 94 | 681 |
| 4K | 125 | 834 |
| 16K | 288 | 1121 |
| 64K | 928 | 6253 |
| 256K | 3715 | 8163 |
| 1M | 7929 | 37067 |
| 4M | 31239 | 142334 |

Allreduce

| Data Size | Mooncake | Gloo |
| --- | --- | --- |
| 1K | 87 | 1334 |
| 4K | 163 | 1358 |
| 16K | 476 | 1482 |
| 64K | 1623 | 1606 |
| 256K | 6382 | 2202 |
| 1M | 23194 | 5324 |
| 4M | 92664 | 15734 |

Broadcast

| Data Size | Mooncake | Gloo |
| --- | --- | --- |
| 1K | 61 | 101 |
| 4K | 87 | 129 |
| 16K | 142 | 177 |
| 64K | 389 | 449 |
| 256K | 1389 | 1130 |
| 1M | 1662 | 2759 |
| 4M | 7876 | 11559 |

Mooncake vs. NCCL

Allgather

| Data Size | Mooncake | NCCL |
| --- | --- | --- |
| 1K | 67 | 93 |
| 4K | 69 | 88 |
| 16K | 78 | 93 |
| 64K | 122 | 84 |
| 256K | 293 | 81 |
| 1M | 1038 | 178 |
| 4M | 4158 | 521 |

Allreduce

| Data Size | Mooncake | NCCL |
| --- | --- | --- |
| 1K | 57 | 34 |
| 4K | 60 | 30 |
| 16K | 77 | 31 |
| 64K | 122 | 30 |
| 256K | 300 | 31 |
| 1M | 1112 | 53 |
| 4M | 14421 | 119 |

Broadcast

| Data Size | Mooncake | NCCL |
| --- | --- | --- |
| 1K | 50 | 28 |
| 4K | 38 | 26 |
| 16K | 47 | 27 |
| 64K | 100 | 28 |
| 256K | 246 | 34 |
| 1M | 834 | 28 |
| 4M | 3196 | 68 |

@UNIDY2002 force-pushed the sunxun/mooncake-backend-dev branch from 82b0d6c to f30e29d on September 15, 2025 03:08
Comment on lines +334 to +341
- name: Install CUDA Toolkit
  uses: Jimver/[email protected]
  with:
    cuda: '12.8.1'
    linux-local-args: '["--toolkit"]'
    method: 'network'
    sub-packages: '["nvcc", "nvrtc-dev"]'
    non-cuda-sub-packages: '["libcusparse-dev", "libcublas-dev", "libcusolver-dev"]'
Collaborator

@xiaguan do you have time to check on this? Do you know if this is supported on our CI machine?

Collaborator

https://github.com/kvcache-ai/Mooncake/actions/runs/17720954259/job/50353039158?pr=805

It compiles successfully in CI, but I'm not sure if the .whl package will actually work for users.

Contributor Author

I think users usually have the full toolkit installed. I tested the .whl in the SGLang Docker environment, and it works :)

@whybeyoung

Amazing work!

Comment on lines +112 to +148
if int(os.getenv("BUILD_WITH_EP", "0")):
    import torch
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    abi_flag = int(torch._C._GLIBCXX_USE_CXX11_ABI)
    current_dir = os.path.abspath(os.path.dirname(__file__))
    ext_modules = [
        CUDAExtension(
            name="mooncake.ep",
            include_dirs=[
                os.path.join(current_dir, "../mooncake-ep/include"),
                os.path.join(current_dir, "../mooncake-transfer-engine/include"),
            ],
            sources=["../mooncake-integration/ep/ep_py.cpp"],
            extra_compile_args={
                "cxx": [f"-D_GLIBCXX_USE_CXX11_ABI={abi_flag}", "-std=c++20"],
                "nvcc": [f"-D_GLIBCXX_USE_CXX11_ABI={abi_flag}", "-std=c++20"],
            },
            libraries=["ibverbs", "mlx5"],
            extra_objects=[
                os.path.join(current_dir, "../build/mooncake-ep/src/libmooncake_ep.a"),
                os.path.join(current_dir, "mooncake/engine.so"),
            ],
        )
    ]
    setup(
        distclass=BinaryDistribution,
        cmdclass={
            "bdist_wheel": CustomBdistWheel,
            "build_ext": BuildExtension,
        },
        ext_modules=ext_modules,
    )
else:
    setup(
        distclass=BinaryDistribution,
        cmdclass={"bdist_wheel": CustomBdistWheel},
    )
Collaborator

Is -std=c++20 the minimum required version? cc: @xiaguan

Collaborator

Mooncake Store needs C++20; the other components could probably use a lower C++ standard such as C++17.

Contributor Author

It seems that a C++20 feature (std::string::starts_with) is used here:

if (server_name.starts_with("[")) {

def dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, broken_ranks: torch.Tensor,
             num_max_dispatch_tokens_per_rank: int, num_experts: int, timeout_us: int,
             use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
Collaborator

This should be fixed as well.

Contributor Author

Changed Tuple[torch.Tensor, torch.Tensor] to Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
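
For reference, the updated annotation would then read roughly as follows (a sketch based on the change described above):

```python
def dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, broken_ranks: torch.Tensor,
             num_max_dispatch_tokens_per_rank: int, num_experts: int, timeout_us: int,
             use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \
        Tuple[Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
    # Presumably, with use_fp8=True the first element is a (data, scales) tuple,
    # otherwise a single tensor, hence the Union.
    ...
```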

@ShangmingCai
Collaborator

I have another urgent PR to test and review today; I will continue with this PR tomorrow.

@alogfans Please take some time to review this PR as well.

TORCH_CHECK(tensorSize * meta->size < kBufferSize, "Too large!");
auto future = c10::make_intrusive<c10::ivalue::Future>(
    c10::ListType::create(c10::TensorType::get()));
int taskId = cpuTaskCount % 2;
Collaborator

Maybe add a comment here for clarification?

Contributor Author

A comment is added.

.attr("__version__")
.attr("split")("+")
.cast<std::vector<std::string>>()[0];
TORCH_CHECK(version == "2.8.0", "Mooncake Backend requires torch==2.8.0");
Collaborator

Should we use >= in case SGLang/vLLM require a newer version of PyTorch?

Contributor Author

I'm afraid a strict equality check is required here, as the Mooncake library must match the libtorch C++ ABI.

If SGLang/vLLM require a newer version of PyTorch, we may have to recompile Mooncake against the corresponding PyTorch version. (Or, to be optimistic, we might figure out a better solution in a future release.)

Collaborator

@ShangmingCai left a comment

This is a huge PR. I have finished several rounds of basic review and found some easy-to-fix problems. I think we can merge this first, after addressing the above comments, to see if we can get some user feedback. CC: @alogfans, please take a look before merging this PR.

@UNIDY2002
Contributor Author

@ShangmingCai Thanks for your review and valuable feedback! I'll fix the issues.

@alogfans
Collaborator

I agree with @ShangmingCai; let's merge it first.

@alogfans merged commit c5829aa into main on Sep 26, 2025
13 checks passed