
Commit b666844

Authored by HollowMan6, hwchen2017, and inkcherry
Fix AutoTP gathering replaced layer params when bias is not None (#7257)

Some params are one-dimensional; this PR adds support for them. Resolves #7249.

```log
param.shape torch.Size([768, 1536])
param.shape torch.Size([768])
...
```

```log
  with deepspeed.module_inject.layers.GatherReplacedLayerParams([param], model, enabled=True):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "deepspeed/module_inject/layers.py", line 359, in __enter__
    self.params[0].gather_params(self.params)
  File "torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "deepspeed/module_inject/layers.py", line 473, in gather_params
    param.shape[1],
    ~~~~~~~~~~~^^^
IndexError: tuple index out of range
```

---------

Signed-off-by: Hollow Man <[email protected]>
Signed-off-by: inkcherry <[email protected]>
Co-authored-by: Hongwei Chen <[email protected]>
Co-authored-by: inkcherry <[email protected]>
1 parent d4032ec commit b666844
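
For reference, the traceback comes from building the all-gather buffer with `param.shape[1]`, which a 1-D bias tensor does not have. Below is a minimal sketch of the shape-agnostic buffer construction used by the patched `gather_params`; the `gather_buffer` helper and the shard sizes are illustrative, not part of the change:

```python
import torch

def gather_buffer(param: torch.Tensor, world_size: int) -> torch.Tensor:
    # The old pattern torch.empty(world_size * param.shape[0], param.shape[1], ...)
    # raises IndexError for a 1-D bias shard such as torch.Size([768]).
    # Expanding dim 0 and splatting any remaining dims handles both cases.
    return torch.empty((world_size * param.shape[0], *param.shape[1:]),
                       dtype=param.dtype, device=param.device)

weight_shard = torch.empty(768, 1536)   # 2-D weight shard
bias_shard = torch.empty(768)           # 1-D bias shard
assert gather_buffer(weight_shard, 2).shape == (1536, 1536)
assert gather_buffer(bias_shard, 2).shape == (1536,)
```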

File tree

2 files changed: +72 -30 lines

deepspeed/module_inject/layers.py (37 additions, 15 deletions)

```diff
@@ -49,6 +49,18 @@ def set_autotp_mode(training=False):
         DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE
 
 
+def add_bias(input, bias):
+    if bias is None:
+        return input
+    if is_autotp_training_mode():
+        # Training mode - avoid inplace to ensure correct autograd
+        input = input + bias
+        return input
+    else:
+        input += bias
+        return input
+
+
 class RowParallel(torch.autograd.Function):
     """
     A custom autograd function for performing row-wise parallelism.
@@ -92,7 +104,7 @@ def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, weight, bias
         ctx.group = group
         output = torch.matmul(input, weight.transpose(-1, -2))
         if bias is not None:
-            output += bias
+            output = add_bias(output, bias)
 
         ctx.save_for_backward(input, weight)
 
@@ -220,6 +232,14 @@ def _tp_partition(self, params_list: List[torch.Tensor]):
         """
         pass
 
+    def config_requires_grad(self, weight):
+        if weight is not None:
+            if self.is_training_mode():
+                if weight.requires_grad is None:
+                    weight.requires_grad = True
+            else:
+                weight.requires_grad = False
+
     def config_tp_params(self, weight):
         """
         Configures the weight tensor for training with tensor parallelism. This includes enabling gradients
@@ -233,15 +253,11 @@ def config_tp_params(self, weight):
         if self.is_training_mode():
             assert self.support_training, "No implementation of backward."
         if weight is not None:
-            if self.is_training_mode():
-                if weight.requires_grad is None:
-                    weight.requires_grad = True
-            else:
-                weight.requires_grad = False
-            setattr(weight, DS_TENSOR_MODEL_PARALLEL, True)
-            setattr(weight, DS_IS_REPLACED_MODULE, True)
+            self.config_requires_grad(weight)
             weight.gather_params = self.gather_params
             weight._tp_partition = self._tp_partition
+            setattr(weight, DS_TENSOR_MODEL_PARALLEL, True)
+            setattr(weight, DS_IS_REPLACED_MODULE, True)
 
     def is_training_mode(self):
         global DEEPSPEED_AUTOTP_MODE
@@ -377,13 +393,14 @@ def __init__(self, module, mp_group, **kwargs):
         self.support_training = True
         self.config_tp_params(self.weight)
         if self.bias is not None:
-            self.config_tp_params(self.bias)
+            # bias here is not tp params
+            self.config_requires_grad(self.bias)
 
     def forward(self, input):
         output = torch.matmul(input, self.weight.transpose(-1, -2))
         output = RowParallel.apply(self.mp_group, output, not self.is_training_mode())
         if self.bias is not None:
-            output += self.bias
+            output = add_bias(output, self.bias)
         return output
 
     @torch.no_grad()
@@ -395,6 +412,7 @@ def gather_params(self, params_list):
                 return
             params_list[idx].data_partition = param.data
             param = param.transpose(0, 1).contiguous()
+
            output_param = torch.empty(self.tp_world_size * param.shape[0],
                                       param.shape[1],
                                       dtype=param.dtype,
@@ -412,9 +430,14 @@ def _tp_partition(self, params_list):
 
         else:
             for idx, param in enumerate(params_list):
-                if param is None or idx > 0:
+                if param is None:
                     # don't slipt bias
                     return
+                if idx > 0:  # move bias to device at initialization
+                    _partition = self.move(param).detach()
+                    params_list[idx].data = _partition
+                    return
+
                 _partition = torch.chunk(param, self.tp_world_size, dim=-1)[self.tp_index]
 
                 _partition = self.move(_partition).detach()
@@ -455,7 +478,7 @@ def forward(self, input):
             input = ColumnParallel.apply(self.mp_group, input)
             output = torch.matmul(input, self.weight.transpose(-1, -2))
             if self.bias is not None:
-                output += self.bias
+                output = add_bias(output, self.bias)
         else:
             output = AsyncColumnParallel.apply(self.mp_group, input, self.weight, self.bias)
 
@@ -467,8 +490,7 @@ def gather_params(self, params_list):
         for idx, param in enumerate(params_list):
 
             params_list[idx].data_partition = param.data
-            output_param = torch.empty(self.tp_world_size * param.shape[0],
-                                       param.shape[1],
+            output_param = torch.empty((self.tp_world_size * param.shape[0], *param.shape[1:]),
                                        dtype=param.dtype,
                                        device=param.device)
             dist.all_gather_into_tensor(output_param, param, group=self.mp_group)
@@ -651,7 +673,7 @@ def forward(self, input):
         if self.mp_group is not None:
             dist.inference_all_reduce(output, group=self.mp_group)
         if self.bias is not None:
-            output += self.bias
+            output = add_bias(output, self.bias)
         return output
 
 
```
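The new `add_bias` helper switches between an out-of-place add during autotp training and an in-place add otherwise. A small standalone sketch of the same pattern, with a plain `training` flag standing in for `is_autotp_training_mode()` (names here are illustrative):

```python
from typing import Optional

import torch


def add_bias(output: torch.Tensor, bias: Optional[torch.Tensor], training: bool) -> torch.Tensor:
    # Mirrors the helper added in layers.py; `training` stands in for
    # is_autotp_training_mode().
    if bias is None:
        return output
    if training:
        # Out-of-place add so autograd records a regular addition node.
        return output + bias
    # Inference path: in-place add avoids an extra allocation.
    output += bias
    return output


x = torch.randn(2, 4, requires_grad=True)
w = torch.randn(3, 4, requires_grad=True)
b = torch.randn(3, requires_grad=True)
y = add_bias(torch.matmul(x, w.t()), b, training=True)
y.sum().backward()  # gradients reach x, w and b
```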
tests/unit/model_parallelism/test_autotp_training.py (35 additions, 15 deletions)

```diff
@@ -19,6 +19,7 @@
 from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer, set_autotp_mode
 from unit.checkpoint.common import compare_lr_scheduler_states, compare_optimizer_states
 import os
+from deepspeed.runtime.utils import is_model_parallel_parameter
 
 
 def skip_on_device():
@@ -30,10 +31,9 @@ class SequentialLinearModel(torch.nn.Module):
 
     def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
         super(SequentialLinearModel, self).__init__()
-        self.linears = torch.nn.ModuleList(
-            [torch.nn.Linear(hidden_dim, hidden_dim, bias=None) for i in range(nlayers)])
+        self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)])
         if empty_grad:
-            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=None)
+            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
         self.empty_grad = empty_grad
 
@@ -153,8 +153,7 @@ def process_linear_layer(hidden_dim, input):
     torch_linear = nn.Linear(hidden_dim,
                              hidden_dim,
                              dtype=preferred_dtype(),
-                             device=get_accelerator().current_device(),
-                             bias=None)
+                             device=get_accelerator().current_device())
     torch_out = torch_linear(input)
     torch_loss = torch_out.sum()
     torch_loss.backward()
@@ -215,6 +214,9 @@ def testRowParallel(self, tp_size: int, tp_overlap_comm: bool):
         loss.backward()
 
         torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=1)[groups.get_tensor_model_parallel_rank()]
+        torch_bias_grad = torch_linear.bias.grad
+        assert torch.allclose(linear.bias.grad, torch_bias_grad.to(get_accelerator().current_device()), atol=1e-3)
+        # The gradient of the weight is not the same as the torch_linear.weight.grad
         assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
         assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-2)
 
@@ -266,6 +268,10 @@ def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool):
 
         cur_device_out = torch.chunk(torch_out, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()]
         torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()]
+
+        torch_bias_grad = torch.chunk(torch_linear.bias.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()]
+        assert torch.allclose(linear.bias.grad, torch_bias_grad.to(get_accelerator().current_device()), atol=1e-3)
+
         assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
         assert torch.allclose(cur_device_out.to(get_accelerator().current_device()).contiguous(),
                               out.contiguous(),
@@ -307,23 +313,36 @@ def test(self, layer_type):
         model = SequentialLinearModel(hidden_dim=hidden_dim)
         model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
 
-        torch_linear = nn.Linear(hidden_dim, hidden_dim, dtype=preferred_dtype(), device="cpu", bias=None)
+        torch_linear = nn.Linear(hidden_dim, hidden_dim, dtype=preferred_dtype(), device="cpu")
         total_params = sum(p.numel() for p in torch_linear.parameters())
-
         tp_layer = None
         if layer_type == "linear":
-            tp_layer = LinearLayer(torch_linear, groups.get_tensor_model_parallel_group())
+            tp_layer = LinearLayer(deepcopy(torch_linear), groups.get_tensor_model_parallel_group())
         elif layer_type == "linearallreduce":
-            tp_layer = LinearAllreduce(torch_linear, groups.get_tensor_model_parallel_group())
+            tp_layer = LinearAllreduce(deepcopy(torch_linear), groups.get_tensor_model_parallel_group())
         else:
             raise ValueError(f"Invalid linear type: {config_dict['linear_type']}")
 
         tp_params = sum(p.numel() for p in tp_layer.parameters())
 
-        assert total_params // tp_size == tp_params
+        expected_tp_params = 0
+        # compute expected TP params:
+        # - column-parallel (LinearLayer): weight & bias both split => total // tp_size
+        # - row-parallel (LinearAllreduce): weight split, bias (1d tensors) replicated
+        if layer_type == "linearallreduce":
+            weight_params = torch_linear.weight.numel()
+            bias_params = torch_linear.bias.numel()
+            expected_tp_params = weight_params // tp_size + bias_params
+        else:
+            expected_tp_params = total_params // tp_size
+        assert expected_tp_params == tp_params, (
+            f"{layer_type}: expected {expected_tp_params} tp params, got {tp_params}")
+
         for name, param in tp_layer.named_parameters(recurse=False):
-            param.gather_params([param])
+            if is_model_parallel_parameter(param):
+                param.gather_params([param])
 
+        torch_linear = torch_linear.to(get_accelerator().current_device())
         is_same_weights = all(
             torch.equal(param1, param2) for param1, param2 in zip(tp_layer.parameters(), torch_linear.parameters()))
 
@@ -333,11 +352,12 @@ def test(self, layer_type):
         assert total_params == params1
 
         for name, param in tp_layer.named_parameters(recurse=False):
-            param._tp_partition([param])
+            if is_model_parallel_parameter(param):
+                param._tp_partition([param])
 
         tp_params2 = sum(p.numel() for p in tp_layer.parameters())
 
-        assert total_params // tp_size == tp_params2
+        assert expected_tp_params == tp_params2
 
 
 def dummy_init_engine(config):
@@ -571,7 +591,7 @@ def test(self, tp_size: int, zero_stage: int):
 
         tp_norm = tp_optimizer._global_grad_norm
 
-        assert math.isclose(base_norm, tp_norm, abs_tol=1e-3)
+        assert math.isclose(base_norm, tp_norm, abs_tol=1e-3), f"base_norm: {base_norm}, tp_norm: {tp_norm}"
         tp_params_numel = sum(p.numel() for p in tp_model.parameters())
         base_params_numel = sum(p.numel() for p in base_model.parameters())
-        assert tp_params_numel < base_params_numel
+        assert tp_params_numel < base_params_numel, f"tp_params_numel: {tp_params_numel}, base_params_numel: {base_params_numel}"
```
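
The updated expected-parameter-count logic in the test distinguishes row-parallel from column-parallel layers: `LinearAllreduce` shards only the weight and keeps the 1-D bias whole on every rank, while `LinearLayer` shards both. A worked example of that arithmetic with assumed sizes (hidden_dim=768, tp_size=2; the test itself derives these from the actual `nn.Linear`):

```python
# Assumed example sizes, for illustration only.
hidden_dim, tp_size = 768, 2

weight_params = hidden_dim * hidden_dim      # 589_824
bias_params = hidden_dim                     # 768
total_params = weight_params + bias_params   # 590_592

# Column-parallel (LinearLayer): weight and bias are both sharded.
linear_expected = total_params // tp_size                      # 295_296

# Row-parallel (LinearAllreduce): weight sharded, 1-D bias replicated per rank.
allreduce_expected = weight_params // tp_size + bias_params    # 295_680

print(linear_expected, allreduce_expected)
```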
