Your current environment
We are working on accelerating RLHF algorithms and need to broadcast the weights of the DeepSpeed engine to the vLLM Ray workers. In v0.4.2 we were able to create an additional NCCL group to achieve this, but after updating to v0.4.3 and incorporating the changes from this MR, doing so causes NCCL errors during the broadcast.
Our weight synchronization code is located at:
- https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/vllm_engine.py
- https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/vllm_worker_wrap.py

See `init_process_group` (builds the NCCL group between vLLM and DeepSpeed, named `self._model_update_group`) and `update_weight` (broadcasts the weights from DeepSpeed to vLLM via `torch.distributed.broadcast(weight, 0, group=self._model_update_group)`); a simplified sketch follows below.
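For reference, here is a simplified sketch of the worker-side pattern. The `create_model_update_group` helper and the dtype/shape handling here are illustrative placeholders, not the exact OpenRLHF code; see the linked vllm_worker_wrap.py for the real implementation.

```python
import torch
import torch.distributed as dist


class WorkerWrap:
    """Simplified sketch of the vLLM worker wrapper used for weight sync.

    In the real code this class extends vLLM's Worker, which provides
    self.model_runner.
    """

    def init_process_group(self, master_address, master_port, rank_offset,
                           world_size, group_name, backend="nccl"):
        # Build an extra process group spanning the DeepSpeed rank-0 trainer
        # and every vLLM worker; the trainer is rank 0 of this group.
        rank = dist.get_rank() + rank_offset
        self._model_update_group = create_model_update_group(  # illustrative helper
            backend=backend,
            init_method=f"tcp://{master_address}:{master_port}",
            world_size=world_size,
            rank=rank,
            group_name=group_name,
        )

    def update_weight(self, name, dtype, shape):
        # Receive one parameter broadcast by the DeepSpeed side over the
        # extra group, then load it into the model held by this worker.
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        dist.broadcast(weight, 0, group=self._model_update_group)
        self.model_runner.model.load_weights(weights=[(name, weight)])
        del weight
```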
We temporarily replaced the NCCL backend with GLOO to make it work, but the performance was poor.
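The GLOO fallback is just a backend swap when that group is created (again using the illustrative helper from the sketch above):

```python
# Temporary workaround: build the update group on GLOO instead of NCCL.
# The weight broadcast then avoids NCCL entirely, but is far slower.
self._model_update_group = create_model_update_group(  # illustrative helper
    backend="gloo",
    init_method=f"tcp://{master_address}:{master_port}",
    world_size=world_size,
    rank=rank,
    group_name=group_name,
)
```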
The error message is:
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148] Error executing method start_worker_execution_loop. This might cause deadlock in distributed execution.
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148] Traceback (most recent call last):
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 140, in execute_method
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     return executor(*args, **kwargs)
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     return func(*args, **kwargs)
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/vllm/worker/worker.py", line 286, in start_worker_execution_loop
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     while self._execute_model_non_driver():
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/vllm/worker/worker.py", line 295, in _execute_model_non_driver
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     data = broadcast_tensor_dict(src=0)
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/vllm/distributed/communication_op.py", line 284, in broadcast_tensor_dict
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     torch.distributed.broadcast_object_list(recv_metadata_list,
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 75, in wrapper
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     return func(*args, **kwargs)
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2649, in broadcast_object_list
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     broadcast(object_sizes_tensor, src=src, group=group)
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 75, in wrapper
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     return func(*args, **kwargs)
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2144, in broadcast
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     work.wait()
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148] RuntimeError: [../third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:81] Timed out waiting 1800000ms for recv operation to complete
Even if we call self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop() before the broadcast, there is still another NCCL error.
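Roughly, the sequence we tried looks like this (schematic only: the two functions run in different processes, and the Ray call that triggers update_weight on each vLLM worker is omitted):

```python
import torch.distributed as dist


def pause_vllm_workers(llm):
    # Runs inside the vLLM engine actor: stop the remote workers' busy
    # loop so they are not parked inside broadcast_tensor_dict() on
    # vLLM's own group while we use the extra update group.
    llm.llm_engine.model_executor.stop_remote_worker_execution_loop()


def broadcast_updated_weights(model, model_update_group):
    # Runs on the DeepSpeed rank-0 trainer: send each updated parameter
    # over the extra group; every vLLM worker posts the matching receive
    # in its update_weight() (see vllm_worker_wrap.py above).
    for name, param in model.named_parameters():
        dist.broadcast(param.data, 0, group=model_update_group)
```

The NCCL error we then get is: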
(LLMRayActor pid=116812) Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1 [2] 1/-1/-1->0->-1 [3] -1/-1/-1->0->1 [4] -1/-1/-1->0->1 [5] -1/-1/-1->0->1 [6] 1/-1/-1->0->-1 [7] 1/-1/-1->0->-1 [8] 1/-1/-1->0->-1 [9] -1/-1/-1->0->1 [10] -1/-1/-1->0->1 [11] -1/-1/-1->0->1
(LLMRayActor pid=116812) a5fa65866c9c:116812:120170 [0] proxy.cc:1336 NCCL WARN Cuda failure 1 'invalid argument'
(LLMRayActor pid=116812) a5fa65866c9c:116812:120158 [0] transport/p2p.cc:272 NCCL WARN Cuda failure 'invalid argument'
(LLMRayActor pid=116812) a5fa65866c9c:116812:120158 [0] NCCL INFO transport/p2p.cc:327 -> 1
(LLMRayActor pid=116812) a5fa65866c9c:116812:120158 [0] NCCL INFO transport/p2p.cc:507 -> 1
(LLMRayActor pid=116812) a5fa65866c9c:116812:120158 [0] NCCL INFO transport.cc:183 -> 1
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] Error executing method update_weight. This might cause deadlock in distributed execution.
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] Traceback (most recent call last):
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] File "/home/jianh/.local/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 140, in execute_method
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] return executor(*args, **kwargs)
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] File "/tmp/ray/session_2024-06-13_13-16-35_468561_107280/runtime_resources/working_dir_files/_ray_pkg_d1835c417c453aec/openrlhf/trainer/ray/vllm_worker_wrap.py", line 39, in update_weight
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] torch.distributed.broadcast(weight, 0, group=self._model_update_group)
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] File "/usr/local/lib/python3.10/dist-packages/torch/distributed/c10d_logger.py", line 75, in wrapper
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] return func(*args, **kwargs)
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 2140, in broadcast
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] work = group.broadcast([tensor], opts)
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1970, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.20.5
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] ncclUnhandledCudaError: Call to CUDA function failed.
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] Last error:
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] Cuda failure 'invalid argument'
(RayWorkerWrapper pid=117839) ERROR 06-13 13:24:49 worker_base.py:148] Error executing method update_weight. This might cause deadlock in distributed execution.
I think our torch.distributed.broadcast(weight, 0, group=self._model_update_group) call may conflict with this MR, but I'm not sure how to fix it.
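To make the suspected interaction concrete, the non-driver workers sit in a loop like the one below (paraphrased from the worker.py frames in the first traceback, not the exact vLLM source). A worker can therefore already be blocked in a collective on vLLM's own group when our extra-group broadcast for update_weight arrives:

```python
# Paraphrase of vllm/worker/worker.py as seen in the traceback above;
# the import path is taken from that traceback as well.
from vllm.distributed.communication_op import broadcast_tensor_dict


def start_worker_execution_loop(self):
    # Each non-driver (tensor-parallel) worker spins here between requests.
    while self._execute_model_non_driver():
        pass


def _execute_model_non_driver(self):
    # Blocks in a broadcast on vLLM's own communication group until the
    # driver publishes the next batch's metadata; this is the call that
    # times out after 1800000 ms in the first traceback.
    data = broadcast_tensor_dict(src=0)
    if not data:
        # An empty payload is treated as the stop signal in this sketch.
        return False
    # ... run the model step with the received metadata (omitted) ...
    return True
```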