
Commit 5841b54

stas00, sfc-gh-sbekman, jeffra, and sfc-gh-truwase authored and committed
Ulysses SP for HF Integration (deepspeedai#7268)
This is the Deepspeed counterpart of snowflakedb/ArcticTraining#45 - as the new feature(s) require changes on both sides.

For PR reviewers: Readiness status:
- [x] Code
- [x] Tests
- [ ] Docs - working on it

Features:
- [x] add support for delaying grad addition via `param.ds_grad_is_ready` flag (used when performing tiled compute in an autograd function)
- [x] add light sp-only mpu version (Jeff Rasley)
- [x] improved debug
- [x] added `all_gather_object` to `dist`
- [x] `UlyssesSPAttentionHF` (port of UlyssesAttention from Megatron-Deepspeed plus modern MHA-variations)
- [x] `UlyssesSPDataLoaderAdapter` - DL adapter to shard the normal DL batches to be used by `UlyssesSPAttentionHF`
- [x] `SequenceTiledCompute` - generic autograd function to perform compute after tiling on the sequence dimension
- [x] `TiledMLP` - a specific autograd function to perform tiled MLP (it's much easier to understand before trying to grok `SequenceTiledCompute`)
- [x] added a differentiable `_DimZeroAllToAll` (Samyam Rajbhandari)
- [x] torch-dist-check now allows `torch.distributed.nn` (which is needed since deepspeed's dist is not up to date with `torch.distributed.nn`)

---------

Signed-off-by: Stas Bekman <[email protected]>
Signed-off-by: Stas Bekman <[email protected]>
Co-authored-by: Stas Bekman <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
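As a rough illustration of the sequence-tiling idea behind `TiledMLP`/`SequenceTiledCompute` (this is not the PR's implementation, which wraps the computation in an autograd function and coordinates gradient reduction via `param.ds_grad_is_ready`), an MLP applied position-wise can be evaluated shard by shard along the sequence dimension:

import torch
import torch.nn as nn

def mlp_forward_tiled(mlp: nn.Module, hidden_states: torch.Tensor, num_shards: int = 4) -> torch.Tensor:
    # Conceptual sketch only: an MLP acts independently on every position, so the
    # sequence dimension can be split into shards and processed one shard at a time,
    # bounding the size of the intermediate activations alive during the forward pass.
    shards = torch.chunk(hidden_states, num_shards, dim=1)  # hidden_states: [batch, seq, hidden]
    return torch.cat([mlp(shard) for shard in shards], dim=1)

# Toy check that tiling does not change the result.
mlp = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
x = torch.randn(2, 128, 16)
assert torch.allclose(mlp(x), mlp_forward_tiled(mlp, x), atol=1e-6)

Note that without an autograd function that recomputes each shard during backward, autograd still saves every shard's activations; handling that is the part the actual `TiledMLP` implementation adds.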
1 parent 1e526a6 commit 5841b54

File tree

21 files changed: +1868 additions, −32 deletions

.github/workflows/nv-ds-chat.yml

Lines changed: 1 addition & 1 deletion
@@ -43,8 +43,8 @@ jobs:

       - name: Install deepspeed
         run: |
-          pip install transformers==4.48.3
           pip install .[dev]
+          pip install transformers==4.48.3
           ds_report

       - name: Install deepspeed-chat

deepspeed/comm/comm.py

Lines changed: 6 additions & 0 deletions
@@ -242,6 +242,12 @@ def all_gather(tensor_list,
     return cdb.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)


+@timed_op
+def all_gather_object(object_list, obj, group=None, prof=False, log_name='all_gather_object', debug=get_caller_func()):
+    global cdb
+    return cdb.all_gather_object(object_list=object_list, obj=obj, group=group)
+
+
 def has_reduce_scatter_tensor():
     global cdb
     assert cdb is not None and cdb.is_initialized(
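For context, the new `deepspeed.comm.all_gather_object` mirrors `torch.distributed.all_gather_object`: it gathers arbitrary picklable Python objects from every rank. A minimal usage sketch (assuming a multi-rank job launched with the deepspeed launcher):

import deepspeed
from deepspeed import comm as dist

deepspeed.init_distributed()

# Each rank contributes one picklable object.
obj = {"rank": dist.get_rank(), "payload": [1, 2, 3]}

# Pre-allocate one slot per rank; all_gather_object fills them in rank order.
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, obj)

if dist.get_rank() == 0:
    print(gathered)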

deepspeed/comm/torch.py

Lines changed: 4 additions & 0 deletions
@@ -268,6 +268,10 @@ def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_
         else:
             reqs[-1].wait()

+    @disable_compiler_collective
+    def all_gather_object(self, object_list, obj, group=None):
+        return torch.distributed.all_gather_object(object_list=object_list, obj=obj, group=group)
+
     @disable_compiler_collective
     def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
         if self.has_reduce_scatter_tensor():

deepspeed/runtime/config.py

Lines changed: 14 additions & 3 deletions
@@ -721,14 +721,23 @@ def __init__(self, config: Union[str, dict], mpu=None, mesh_device=None):
             raise ValueError(
                 f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. Received: {config}"
             )
+
         try:
             self.global_rank = dist.get_rank()
             if mpu is not None:
-                self.world_size = mpu.get_data_parallel_world_size()
+                # Ulysses SP
+                if not hasattr(mpu, "get_data_parallel_world_size"):
+                    self.world_size = dist.get_world_size() / mpu.get_sequence_parallel_world_size()
+                else:
+                    self.world_size = mpu.get_data_parallel_world_size()
             elif mesh_device is not None:
                 self.world_size = dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
             else:
-                self.world_size = dist.get_world_size()
+                # HF zero.init case where there is no mpu
+                if "sequence_parallel_size" in config:
+                    self.world_size = dist.get_world_size() / config["sequence_parallel_size"]
+                else:
+                    self.world_size = dist.get_world_size()
         except:
             self.global_rank = 0
             self.world_size = 1

@@ -941,7 +950,7 @@ def _set_batch_related_parameters(self):
         micro_batch = self.train_micro_batch_size_per_gpu
         grad_acc = self.gradient_accumulation_steps

-        #print(f"train_batch = {train_batch}, micro_batch={micro_batch}")
+        #print(f"in: train_batch = {train_batch}, micro_batch={micro_batch}")

         # all values are provided nothing needs to be set
         if train_batch is not None and micro_batch is not None and grad_acc is not None:

@@ -980,6 +989,8 @@ def _set_batch_related_parameters(self):
             assert False, \
                 'Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided'

+        #print(f"final: {self.train_batch_size=} {self.train_micro_batch_size_per_gpu=} {self.gradient_accumulation_steps=}")
+
     def _configure_train_batch_size(self):
         self._set_batch_related_parameters()
         self._batch_assertion()
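The net effect of the new branches is that, under Ulysses SP, the data-parallel world size used for batch-size bookkeeping is the global world size divided by the sequence-parallel degree. A worked sketch of the arithmetic (hypothetical numbers: 8 ranks, `sequence_parallel_size: 2` in the DeepSpeed config, no mpu passed in):

world_size = 8                     # dist.get_world_size()
sequence_parallel_size = 2         # config["sequence_parallel_size"]
dp_world_size = world_size / sequence_parallel_size   # 4.0: ranks in one SP group share a single sample stream

# train_batch_size = micro_batch * grad_acc * dp_world_size
micro_batch, grad_acc = 2, 4
train_batch_size = micro_batch * grad_acc * int(dp_world_size)  # 32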

deepspeed/runtime/engine.py

Lines changed: 9 additions & 0 deletions
@@ -1354,6 +1354,15 @@ def _configure_distributed_model(self, model):
             self.communication_data_type = self._config.seq_parallel_communication_data_type
             self.seq_parallel_group = groups._get_sequence_parallel_group()

+            if dist.get_rank() == 0:
+                summary = "********** distributed groups summary **********\n"
+                summary += f"\t {self.dp_world_size=}\n"
+                summary += f"\t {self.mp_world_size=}\n"
+                summary += f"\t {self.seq_dp_world_size=}\n"
+                summary += f"\t {self.sequence_parallel_size=}\n"
+                summary += "***********************************************"
+                logger.info(summary)
+
         if not (self.amp_enabled() or is_zero_init_model):
             self._broadcast_model()

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# Copyright (c) The DeepSpeed Contributors
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+# Copyright (c) The DeepSpeed Contributors
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+"""
+This is a slimmed-down version of parallel_state.py (mpu) from Megatron-Deepspeed
+"""
+
+from deepspeed import comm as dist
+
+# Sequence parallel groups to handle both data and sequence parallelisms.
+# These groups are used to reduce gradients and shard parameters and optimizer stages for ZeRO.
+_SEQUENCE_PARALLEL_GROUP = None
+_SEQUENCE_DATA_PARALLEL_GROUP = None
+
+
+def initialize_sequence_parallel(sequence_parallel_size: int) -> None:
+    """Initialize sequence parallel groups."""
+
+    assert dist.is_initialized()
+    world_size: int = dist.get_world_size()
+
+    if world_size < sequence_parallel_size:
+        raise RuntimeError(f"world_size ({world_size}) is less than sequence_parallel_size {sequence_parallel_size}")
+
+    if sequence_parallel_size <= 1:
+        raise ValueError(f"sequence_parallel_size must be greater than 1, got {sequence_parallel_size}")
+
+    if world_size % sequence_parallel_size != 0:
+        raise RuntimeError(
+            f"world_size ({world_size}) is not divisible by sequence_parallel_size {sequence_parallel_size})")
+
+    data_parallel_size: int = world_size // sequence_parallel_size
+    sequence_data_parallel_size: int = sequence_parallel_size * data_parallel_size
+    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
+    num_sequence_data_parallel_groups: int = world_size // sequence_parallel_size // data_parallel_size
+
+    rank = dist.get_rank()
+
+    # Build the sequence parallel groups.
+    global _SEQUENCE_PARALLEL_GROUP
+    assert _SEQUENCE_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
+    for i in range(num_sequence_parallel_groups):
+        ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
+        group = dist.new_group(ranks)
+        if rank in ranks:
+            _SEQUENCE_PARALLEL_GROUP = group
+
+    # Build the sequence data parallel groups.
+    global _SEQUENCE_DATA_PARALLEL_GROUP
+    assert _SEQUENCE_DATA_PARALLEL_GROUP is None, "sequence data parallel group is already initialized"
+    all_data_sequence_parallel_group_ranks = []
+    for i in range(num_sequence_data_parallel_groups):
+        ranks = range(i * sequence_data_parallel_size, (i + 1) * sequence_data_parallel_size)
+        group = dist.new_group(ranks)
+        all_data_sequence_parallel_group_ranks.append(list(ranks))
+        if rank in ranks:
+            _SEQUENCE_DATA_PARALLEL_GROUP = group
+
+
+def get_sequence_parallel_group():
+    """Get the sequence parallel group the caller rank belongs to."""
+    assert _SEQUENCE_PARALLEL_GROUP is not None, "sequence parallel group is not initialized"
+    return _SEQUENCE_PARALLEL_GROUP
+
+
+def get_sequence_data_parallel_group():
+    """Get the sequence data parallel group the caller rank belongs to."""
+    assert _SEQUENCE_DATA_PARALLEL_GROUP is not None, "sequence data parallel group is not initialized"
+    return _SEQUENCE_DATA_PARALLEL_GROUP
+
+
+def get_sequence_parallel_world_size():
+    """Return world size for the sequence parallel group."""
+    return dist.get_world_size(group=get_sequence_parallel_group())
+
+
+def get_sequence_data_parallel_world_size():
+    """Return world size for the sequence data parallel group."""
+    return dist.get_world_size(group=get_sequence_data_parallel_group())
+
+
+def get_sequence_parallel_rank():
+    """Return my rank for the sequence parallel group."""
+    return dist.get_rank(group=get_sequence_parallel_group())
+
+
+def get_sequence_data_parallel_rank():
+    """Return my rank for the sequence data parallel group."""
+    return dist.get_rank(group=get_sequence_data_parallel_group())
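A minimal sketch of how this slimmed-down mpu is meant to be consumed (the import path is hypothetical since the new file's name is not shown above; assuming an 8-rank job and a sequence-parallel degree of 2, which yields 4 sequence-parallel groups of 2 ranks each):

import deepspeed
from deepspeed import comm as dist
import parallel_state_sp as sp_state  # hypothetical import path for the module above

deepspeed.init_distributed()
sp_state.initialize_sequence_parallel(sequence_parallel_size=2)

print(f"rank {dist.get_rank()}: "
      f"sp_rank={sp_state.get_sequence_parallel_rank()} of "
      f"sp_world_size={sp_state.get_sequence_parallel_world_size()}")

# Per the module docstring, the sequence data parallel group is what ZeRO uses to
# reduce gradients and shard parameters/optimizer states across SP and DP ranks.
sdp_group = sp_state.get_sequence_data_parallel_group()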
