
Commit 185330c

limjcst, Mingjie Li, and tohtana authored
Support complicated use cases with TiedLayerSpec (#7208)
I want to reuse a composed module in the pipeline. For example, the following `MyModule` has a member `linear`, which is itself a module.

```python
class MyModule(torch.nn.Module):

    def __init__(self, n_in: int, n_out: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_in, n_out)
        self.layer_norm = torch.nn.LayerNorm(n_out)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        hidden = self.linear(data)
        hidden = self.layer_norm(hidden)
        return hidden
```

`MyModule.linear.weight` should be synchronized among the related ranks. As a result, I add `linear.weight` to `TiedLayerSpec.tied_weight_attr`. By the way, I generate the whole `tied_weight_attr` list with the following instruction.

```python
tied_weight_attr = [name for name, p in layer.named_parameters() if p.numel() > 1]
```

However, the builtin `getattr` used by `PipelineModule` fails to find a nested attribute like `linear.weight`. Hence, this PR first extends the builtin `getattr` to a recursive version, `PipelineModule._recursive_getattr`, which accesses each attribute segment one by one.

Meanwhile, the order of the tied weights matters during synchronization. This PR therefore sorts `tie_keys` in `PipelineModule._index_tied_modules` to avoid hanging.

Signed-off-by: Mingjie Li <[email protected]>
Co-authored-by: Mingjie Li <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
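For context, here is a minimal sketch of how the use case described above could be declared. The `"tied_module"` key, the layer sizes, and `num_stages` are illustrative assumptions and not part of this commit; the list form of `tied_weight_attr` follows the instruction quoted above.

```python
import torch
from deepspeed.pipe import LayerSpec, PipelineModule, TiedLayerSpec


class MyModule(torch.nn.Module):
    """Composed module from the description above: a Linear followed by a LayerNorm."""

    def __init__(self, n_in: int, n_out: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_in, n_out)
        self.layer_norm = torch.nn.LayerNorm(n_out)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        return self.layer_norm(self.linear(data))


# Collect every nested parameter name, as in the instruction quoted above.
layer = MyModule(512, 512)
tied_weight_attr = [name for name, p in layer.named_parameters() if p.numel() > 1]
# -> ['linear.weight', 'linear.bias', 'layer_norm.weight', 'layer_norm.bias']

# Reuse the same module (key "tied_module") at two pipeline positions.
specs = [
    TiedLayerSpec("tied_module", MyModule, 512, 512, tied_weight_attr=tied_weight_attr),
    LayerSpec(torch.nn.Linear, 512, 512),
    TiedLayerSpec("tied_module", MyModule, 512, 512, tied_weight_attr=tied_weight_attr),
]

# Requires an initialized distributed environment (e.g. the deepspeed launcher);
# num_stages is illustrative.
model = PipelineModule(layers=specs, num_stages=2)
```

With the plain builtin `getattr`, the nested names in `tied_weight_attr` raise an `AttributeError` inside `PipelineModule`; the recursive lookup added by this commit resolves them segment by segment.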
1 parent 56005d2 commit 185330c

File tree

1 file changed: +15 −4 lines changed


deepspeed/runtime/pipe/module.py

Lines changed: 15 additions & 4 deletions
```diff
@@ -443,26 +443,34 @@ def _partition_layers(self, method='uniform'):
 
         self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1])
 
+    @staticmethod
+    def _recursive_getattr(module: torch.nn.Module, attr_name: str) -> torch.Tensor:
+        '''Allow getting an attribute like "linear.weight"'''
+        weight = module
+        for item in attr_name.split("."):
+            weight = getattr(weight, item)
+        return weight
+
     def allreduce_tied_weight_gradients(self):
         '''All reduce the gradients of the tied weights between tied stages'''
         for key, comm in self.tied_comms.items():
             for attr_name in comm['weight_attr']:
-                weight = getattr(self.tied_modules[key], attr_name)
+                weight = self._recursive_getattr(self.tied_modules[key], attr_name)
                 dist.all_reduce(weight.grad, group=comm['group'])
 
     def get_tied_weights_and_groups(self):
         weight_group_list = []
         for key, comm in self.tied_comms.items():
             for attr_name in comm['weight_attr']:
-                weight = getattr(self.tied_modules[key], attr_name)
+                weight = self._recursive_getattr(self.tied_modules[key], attr_name)
                 weight_group_list.append((weight, comm['group']))
         return weight_group_list
 
     def _synchronize_tied_weights(self):
         for key, comm in self.tied_comms.items():
             for attr_name in comm['weight_attr']:
                 dist.broadcast(
-                    getattr(comm['module'], attr_name),
+                    self._recursive_getattr(comm['module'], attr_name),
                     src=min(comm['ranks']),
                     group=comm['group'],
                 )
@@ -475,7 +483,10 @@ def _index_tied_modules(self):
 
         specs = self._layer_specs
         tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec))
-        for key in tie_keys:
+        # Since Python 3.7, "Dictionary order is guaranteed to be insertion order."
+        # Sort tie_keys here so that orders of self.tied_comms.items() are consistent
+        # among ranks.
+        for key in sorted(tie_keys):
             # Find the layers that the tied module appears in
             tied_layers = []
             for idx, layer in enumerate(specs):
```
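As a standalone illustration (not code from this commit) of why the helper is needed: the builtin `getattr` treats `"linear.weight"` as a single attribute name, while walking the dotted path one segment at a time resolves the nested parameter.

```python
import torch


class MyModule(torch.nn.Module):

    def __init__(self, n_in: int, n_out: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_in, n_out)
        self.layer_norm = torch.nn.LayerNorm(n_out)


def recursive_getattr(module: torch.nn.Module, attr_name: str):
    # Same idea as PipelineModule._recursive_getattr: resolve each
    # dot-separated segment in turn.
    obj = module
    for item in attr_name.split("."):
        obj = getattr(obj, item)
    return obj


m = MyModule(4, 8)

# getattr(m, "linear.weight") raises AttributeError: the whole string is
# treated as one attribute name.

# The recursive lookup returns the nested parameter tensor:
assert recursive_getattr(m, "linear.weight") is m.linear.weight
```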
