Use new DeviceMesh unflatten to rewrite parallel_dims #1660
base: main
Changes from all commits: 565a5f6, 5fa85d9, 234f80e, baaa3ea, 3f4181e, 70be316, 3716135, a6078b8
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from dataclasses import dataclass

from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
@@ -25,6 +26,7 @@ class ParallelDims:
    ep: int
    etp: int
    world_size: int
    mesh_dim_names: tuple[str] = tuple()

    _world_mesh: DeviceMesh = None
@@ -63,6 +65,139 @@ def _validate(self):
        # EP would borrow all cp and tp and some dp_shard degree
        assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0
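Illustration (not part of the diff): a worked check of the constraint above with assumed degrees. EP borrows all of cp and tp plus part of dp_shard, so ep must be a multiple of cp * tp and must evenly divide dp_shard * cp * tp.

```python
# Illustrative degrees only; these values are assumptions, not from the PR.
dp_shard, cp, tp, ep = 4, 2, 2, 4

assert ep % (cp * tp) == 0             # 4 % (2 * 2) == 0
assert (dp_shard * cp * tp) % ep == 0  # (4 * 2 * 2) % 4 == 0
```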
    def build_mesh(self) -> "ParallelDims":
        """Build the device mesh with the required mesh dimensions.

        The following mesh dimensions may be created based on the parallel configuration:

        pp: For PP.
        dp_replicate: For DDP or HSDP replicate dimension.
        dp_shard_cp: For FSDP or HSDP shard dimension. This includes
            ``cp`` even if ``cp`` is 1. As a result, we always
            use the name ``dp_shard_cp``, and ``dp_shard`` is not
            created as a dimension.

Review comment: Is it harmful to just create dp_shard anyway, for symmetry? Are you trying to stop people from accidentally using the wrong mesh dim axis because they weren't thinking about CP?

Reply: @ezyang https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py#L101 I think without lazy init of PG, this …

        dp_cp: This is used by loss all-reduce. It includes ``dp_replicate``,
            ``dp_shard``, and ``cp`` as all of them are data parallelisms.
        dp: This is used by data loading to decide the global batch size and
            which part of data this rank should read. This dim includes both
            ``dp_replicate`` and ``dp_shard``.
            The name is confusing; ``batch`` could be a better name.

Review comment: I agree.

Reply: Yes, this I agree too!

Reply: If we're going to pick a JAX-y name "batch" for dp, why don't we just extend this principle to all the mesh dims? The rule is you name the dimension after the most important thing that will be sharded by it.

Reply: @ezyang In general, I'd prefer the atomic mesh dim names to be aligned with parallelism, and the flattened dims to the actual usage, instead of …

Reply: Yea, this framing makes sense to me.

        cp: For CP.
        tp: For TP.
        ep: For EP.
        dp_shard_in_ep: For FSDP or HSDP shard dimension in the EP region.

Review comment: I had to use conditions like … because otherwise it's too heavy code.

Reply: I think …

Reply: Not exactly the same.

        Note: These dimensions won't exist at the same time. If we consider
        the unflatten() operator only, the following are all the meshes required,
        assuming all degrees are > 1 except for ``pp``:

        ["dp", "cp", "tp"]: The ``dp`` process group is wasted as the dataloader
            doesn't need it for communication.
        ["dp_cp", "tp"]: Loss computation.
        ["dp_replicate", "dp_shard_cp", "tp"]: Non-EP region computation.
        ["dp_replicate", "dp_shard_in_ep", "ep", "tp"]: EP region computation if etp == tp.
        ["dp_replicate", "dp_shard_in_ep", "ep"]: EP region computation if etp == 1.

        In reality, we don't actually need to create all of these meshes.
        For example, ``dp_cp`` can be sliced and flattened from ["dp", "cp", "tp"],
        so we don't actually need to create ["dp_cp", "tp"].

        But there are some meshes we MUST create if that mesh will be used for a
        parameter. So the non-EP-region computation mesh and the EP-region
        computation mesh are required.
        """

        def add_dim(name, degree, *configs):
            # Append this dim's name and degree to each target mesh config.
            for config in configs:
                config["name"].append(name)
                config["degree"].append(degree)

        world_mesh = init_device_mesh(device_type, [self.world_size])
Review comment: One thought is that we might want to define the semantics of device-mesh creation the following way (for the purpose of this, I am pretending): the first case would initialize a world group and then split subgroups off of it; the second case would not initialize the world group, it would initialize separate groups a, b, c. I bring this up because at large scale it is expensive to initialize the world group, so it is good to let users choose what they actually want to happen.

Reply: This PR actually follows this proposal. It's just that we are using …

Reply: I think we want to keep …

Reply: It depends on the benefit of …
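Illustration (not part of the thread): a minimal sketch of the two creation styles being contrasted, since the original code snippets did not survive in this view. The device type, degrees, and dim names are assumptions, `_unflatten` is the private DeviceMesh API this PR uses, and the snippet assumes an already-initialized 8-rank distributed environment.

```python
from torch.distributed.device_mesh import init_device_mesh

# Case 1: materialize the flat world mesh, then derive named dims from it
# (the route this PR takes via the private DeviceMesh._unflatten API).
world_mesh = init_device_mesh("cuda", (8,))
spmd_mesh = world_mesh._unflatten(0, [2, 2, 2], ["dp", "cp", "tp"])

# Case 2: create the named dims directly; under the proposed semantics this
# form would never have to initialize the expensive world group.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp", "cp", "tp"))
```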

        dp_shard_in_ep = (
            self.dp_shard * self.cp // self.ep
            if self.etp == self.tp
            else self.dp_shard * self.cp * self.tp // self.ep
        )
Review comment: Ideally we shouldn't do any "real math" in this file; see comment above.
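Illustration (not part of the thread): the arithmetic behind dp_shard_in_ep with assumed degrees, covering both branches of the conditional above. The asserts just confirm that in both cases the EP-region dims cover the same dp_shard * cp * tp ranks as the non-EP region.

```python
# Illustrative degrees only; these values are assumptions, not from the PR.
dp_shard, cp, tp, ep = 8, 2, 2, 8

# etp == tp: EP borrows cp, tp, and part of dp_shard.
dp_shard_in_ep = dp_shard * cp // ep                   # 16 // 8 == 2
assert dp_shard_in_ep * ep * tp == dp_shard * cp * tp  # 2 * 8 * 2 == 32

# etp == 1: EP borrows the tp degree as well.
dp_shard_in_ep = dp_shard * cp * tp // ep              # 32 // 8 == 4
assert dp_shard_in_ep * ep == dp_shard * cp * tp       # 4 * 8 == 32
```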

        data_mesh_dims = defaultdict(list)
Review comment: I was thinking we just need two (dense and sparse); any specific reason we have to have three?

Reply: This one can be unflattened from the two meshes you mentioned. But either way, we need …

Reply: After second thought, I'm not sure we actually want to just use two meshes and derive the others from them. The reason is that if we are trying to create a …

Reply: I understand that … Also, when I say two global meshes, I mean the dense part to be … So I'd view …

Reply: If we are going to do …

Reply: You can do … Yes, I also prefer using … More precisely, I'm thinking about what the semantics are for combining two device meshes. For example, if the original DeviceMesh is … We currently circumvent this problem by getting the root mesh and then inferring the SPMD mesh.

Reply: Oh got it. I think one proposal is to create every required flattened mesh together with the atomic meshes during DeviceMesh init, so theoretically there's no need to call flatten/unflatten. I personally think it sounds cleaner.

Reply: I think this should be a DeviceMesh util. E.g. can we use the composition from CuTe?

Reply: It's actually not composition, it should be concatenation. In @fegin's example, the full layout is (2, 2):(2, 1), the TP mesh is 2:1, and the FSDP mesh is 2:2. You concatenate them together (FSDP + TP) and you get back the full layout. Also, while we print the meshes based on the current rank coordinate, IMO there's a sense in which the DeviceMesh encodes "all the ranks", and we "pick" out the correct coordinate very late, e.g., when split_group is called. So it's not really true that you're missing 4, because the global_ranks concept in @fduwjj's PR is actually returning something like [[1, 2], [3, 4]] or [[1, 3], [2, 4]], which is passed as the splits for the PG. Unrestricted concatenation can result in some goofy layouts though. For example, if you concatenate TP with itself you end up with (2, 2):(1, 1), aka [[1, 2], [2, 3]], which I'm pretty sure there's no sensible concept for. So you would want to restrict concatenation to be disjoint.

Reply: I see. My comment may only be true for the original DeviceMesh implementation, where I'm pretty sure people cannot concatenate two "small" meshes into a "big" mesh, even though semantics-wise DeviceMesh should behave like what @ezyang says -- it encodes "all the ranks". If the new implementation (which I haven't fully understood why it can do concatenation yet, lol) can support concatenation, then the SPMD mesh shouldn't be an issue even without … This is probably going to solve the "PP has to be the outermost dimension" issue.

Reply: With CuTe layout, it is natural to extend it to support concatenation, and this might also be needed within the implementation of unflatten. And yes, it is not composition; we need to implement it in a separate PR. The global ranks returned need to be mapped with the signature of …
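Illustration (not part of the thread): a tiny standalone sketch, unrelated to the DeviceMesh internals, that enumerates the rank grid of a CuTe-style (shape):(stride) layout for the (2, 2):(2, 1) example above, using 0-based ranks.

```python
# Rank at coordinate (i, j) of a 2-D shape:stride layout is
# i * stride[0] + j * stride[1].
def layout_ranks(shape, stride):
    return [
        [i * stride[0] + j * stride[1] for j in range(shape[1])]
        for i in range(shape[0])
    ]

full = layout_ranks((2, 2), (2, 1))  # [[0, 1], [2, 3]]
# The TP sub-layout 2:1 groups ranks {0, 1} and {2, 3}; the FSDP sub-layout 2:2
# groups ranks {0, 2} and {1, 3}. Concatenating FSDP (2:2) with TP (2:1)
# recovers the full (2, 2):(2, 1) layout.
print(full)
```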

        non_ep_computation_dims = defaultdict(list)
        ep_computation_dims = defaultdict(list)

        if self.pp_enabled:
            add_dim("pp", self.pp, data_mesh_dims)
            add_dim("pp", self.pp, non_ep_computation_dims)
            add_dim("pp", self.pp, ep_computation_dims)

        if self.dp_enabled:
            add_dim("dp", self.dp_replicate * self.dp_shard, data_mesh_dims)
            if self.dp_replicate_enabled:
                add_dim("dp_replicate", self.dp_replicate, non_ep_computation_dims)
                add_dim("dp_replicate", self.dp_replicate, ep_computation_dims)
            if self.dp_shard_enabled:
                add_dim("dp_shard_cp", self.dp_shard * self.cp, non_ep_computation_dims)
                add_dim("dp_shard_in_ep", dp_shard_in_ep, ep_computation_dims)

        if self.cp_enabled:
            add_dim("cp", self.cp, data_mesh_dims)

        if self.tp_enabled:
            add_dim("tp", self.tp, data_mesh_dims, non_ep_computation_dims)
            if self.etp == self.tp:
                add_dim("tp", self.tp, ep_computation_dims)

        self._all_meshes = []
Review comment: Hm, storing a list of all_meshes and then having to index into it feels like a UX regression. I wonder if we need to assemble the meshes into such a list at all? Should we just store …

Reply: UX regression to whom? I think it is a UX regression to TorchTitan … As for people who use …

Reply: My comment was more focused on the choice to put all_meshes into an ordered list and access the most recent one by [-1]. I would maybe either store them in a dict with descriptive names, or just store them directly on ParallelDims as optional attributes.
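Illustration (not part of the thread): a minimal sketch of the dict alternative suggested in the last reply. The keys, degrees, and dim names are assumptions, the snippet assumes an already-initialized 8-rank environment, and _unflatten/_flatten are the private DeviceMesh APIs this diff already relies on.

```python
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh("cuda", (8,))
meshes = {
    # dataloader dims, keyed by purpose instead of positional [0]
    "data": world_mesh._unflatten(0, [4, 2], ["dp", "cp"]),
    # non-EP computation dims, instead of self._all_meshes[-1]
    "dense": world_mesh._unflatten(0, [4, 2], ["dp_shard_cp", "tp"]),
}
dp_cp_mesh = meshes["data"]["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
```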

        if self.dp_enabled:
            data_mesh = world_mesh._unflatten(
                0, data_mesh_dims["degree"], data_mesh_dims["name"]
            )
            self._all_meshes.append(data_mesh)
            # Note that we don't create loss_mesh as it is easier to flatten
            # from data_mesh
            if self.cp_enabled:
                self._all_meshes[-1]["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
            else:
                self._all_meshes[-1]["dp"]._flatten(mesh_dim_name="dp_cp")

        if self.dp_cp_enabled or self.tp_enabled or self.pp_enabled:
            self._all_meshes.append(
                world_mesh._unflatten(
                    0,
                    non_ep_computation_dims["degree"],
                    non_ep_computation_dims["name"],
                )
            )

        if self.ep_enabled:
            add_dim("ep", self.ep, ep_computation_dims)
            self._all_meshes.append(
                world_mesh._unflatten(
                    0, ep_computation_dims["degree"], ep_computation_dims["name"]
                )
            )

        self._world_mesh = world_mesh
        self.mesh_dim_names = tuple(
            name for m in self._all_meshes for name in m.mesh_dim_names
        )
        return self

    def __getitem__(self, name):
        # This is a hack to make ParallelDims behave like a DeviceMesh.
        # We will need to change trainer if design is concluded. For now,
        # this is just a quick hack to make it work with unflatten()

        if "mesh_dim_names" == name:
            return [name for m in self._all_meshes for name in m.mesh_dim_names]

        for mesh in self._all_meshes:
            try:
                submesh = mesh[name]
                return submesh
            except KeyError:
                pass
        raise AttributeError(f"ParallelDims has no attribute {name}")
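Illustration (not part of the diff): hypothetical trainer-side usage of this hack. The field values are made up (world_size=16 split as dp_shard=4, cp=2, tp=2, ep=4, etp=2), the constructor fields are inferred from this diff, and a matching 16-rank distributed environment is assumed.

```python
# Hypothetical usage; degrees and field names are assumptions based on the diff.
parallel_dims = ParallelDims(
    dp_replicate=1, dp_shard=4, cp=2, tp=2, pp=1, ep=4, etp=2, world_size=16
)
mesh = parallel_dims.world_mesh  # lazily calls build_mesh() and returns self
tp_mesh = mesh["tp"]             # __getitem__ searches every built sub-mesh
dp_cp_mesh = mesh["dp_cp"]       # flattened loss dim created in build_mesh()
```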

    """
    def build_mesh(self) -> DeviceMesh:
        # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel
        # is not very clean, due to the limited support from DeviceMesh

@@ -188,14 +323,19 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
         mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")

         return mesh
+    """

     @property
-    def world_mesh(self) -> str:
+    def world_mesh(self) -> "ParallelDims":
+        # This is a hack to make ParallelDims behave like a DeviceMesh.
+        # We will need to change trainer if design is concluded. For now,
+        # this is just a quick hack to make it work with unflatten()
+
         # doing late init so ParallelDims can still be used as a lightweight
         # dataclass without having to initialize the world mesh
         if self._world_mesh is None:
-            self._world_mesh = self.build_mesh()
-        return self._world_mesh
+            self.build_mesh()
+        return self

     @property
     def dp_enabled(self):

Review comment: Maybe just call it fsdp?