
Commit 6b27cde

chaojun-zhang authored and eicherseiji committed
[XPU] support data parallel for MoE models on XPU (vllm-project#22887)
Signed-off-by: chzhang <[email protected]>
1 parent 3eb258d commit 6b27cde
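
As a rough usage sketch (not part of this commit): with this change, an MoE model could in principle be served with data parallelism on an XPU host. Everything below except VLLM_ALL2ALL_BACKEND (which appears in the diff) is an illustrative assumption about vLLM's offline API; the model choice and the data_parallel_size / enable_expert_parallel engine arguments are not taken from this commit.

```python
# Hedged sketch, not from the commit: exercising data-parallel MoE on XPU.
# Assumes vLLM's offline LLM API forwards data_parallel_size and
# enable_expert_parallel to the engine arguments; the model is arbitrary.
import os

os.environ["VLLM_ALL2ALL_BACKEND"] = "naive"  # backend checked by XpuCommunicator below

from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # illustrative MoE model
    data_parallel_size=2,          # assumption: DP knob exposed via engine args
    enable_expert_parallel=True,   # assumption: EP toggle for MoE layers
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```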

2 files changed: 13 additions & 0 deletions


vllm/distributed/device_communicators/xpu_communicator.py

Lines changed: 11 additions & 0 deletions
@@ -7,8 +7,13 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+import vllm.envs as envs
+from vllm.logger import init_logger
+
 from .base_device_communicator import DeviceCommunicatorBase
 
+logger = init_logger(__name__)
+
 
 class XpuCommunicator(DeviceCommunicatorBase):
 
@@ -18,6 +23,12 @@ def __init__(self,
                  device_group: Optional[ProcessGroup] = None,
                  unique_name: str = ""):
         super().__init__(cpu_group, device, device_group, unique_name)
+        if self.use_all2all:
+            all2all_backend = envs.VLLM_ALL2ALL_BACKEND
+            if all2all_backend == "naive":
+                from .all2all import NaiveAll2AllManager
+                self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+                logger.info("Using naive all2all manager.")
 
     def all_reduce(self, input_) -> torch.Tensor:
         dist.all_reduce(input_, group=self.device_group)
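
As a self-contained illustration of what the new __init__ branch does: when all2all is in use, the communicator reads VLLM_ALL2ALL_BACKEND and wires up only the "naive" manager on XPU. The sketch below mirrors that control flow with stand-ins defined locally (NaiveAll2AllManager here is a placeholder, pick_all2all_manager is a hypothetical helper, not vLLM code).

```python
# Minimal sketch of the backend-selection logic added above, outside vLLM's classes.
import os

class NaiveAll2AllManager:          # placeholder for vllm's .all2all.NaiveAll2AllManager
    def __init__(self, cpu_group):
        self.cpu_group = cpu_group

def pick_all2all_manager(use_all2all: bool, cpu_group=None):
    """Mirror of the new __init__ branch: only the 'naive' backend is handled."""
    if not use_all2all:
        return None
    backend = os.environ.get("VLLM_ALL2ALL_BACKEND", "naive")
    if backend == "naive":
        return NaiveAll2AllManager(cpu_group)
    return None  # other backends are not wired up by this XPU path

print(type(pick_all2all_manager(True)).__name__)  # -> NaiveAll2AllManager
```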

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 2 additions & 0 deletions
@@ -655,6 +655,8 @@ def forward_tpu(
         forward_native = forward_tpu
     elif current_platform.is_cpu():
         forward_native = forward_cpu
+    elif current_platform.is_xpu():
+        forward_native = forward_xpu
     else:
         forward_native = forward_cuda
 
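
For context, the lines added to layer.py extend a class-level dispatch pattern: forward_native is bound once, at class-definition time, to the platform-specific forward method, so callers always invoke forward_native. The sketch below reproduces that pattern with toy stand-ins (FakePlatform, TinyMoE); it is illustrative only and uses none of vLLM's real classes.

```python
# Illustrative sketch of the class-level dispatch pattern the diff extends.
class FakePlatform:
    def __init__(self, name):
        self.name = name
    def is_tpu(self): return self.name == "tpu"
    def is_cpu(self): return self.name == "cpu"
    def is_xpu(self): return self.name == "xpu"

current_platform = FakePlatform("xpu")

class TinyMoE:
    def forward_cuda(self, x): return f"cuda:{x}"
    def forward_cpu(self, x): return f"cpu:{x}"
    def forward_tpu(self, x): return f"tpu:{x}"
    def forward_xpu(self, x): return f"xpu:{x}"

    # Same idea as the diff: bind forward_native to the platform-specific
    # method once, while the class body executes.
    if current_platform.is_tpu():
        forward_native = forward_tpu
    elif current_platform.is_cpu():
        forward_native = forward_cpu
    elif current_platform.is_xpu():
        forward_native = forward_xpu
    else:
        forward_native = forward_cuda

print(TinyMoE().forward_native("tokens"))  # -> xpu:tokens
```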
