from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer, set_autotp_mode
from unit.checkpoint.common import compare_lr_scheduler_states, compare_optimizer_states
import os
+from deepspeed.runtime.utils import is_model_parallel_parameter


def skip_on_device():
@@ -30,10 +31,9 @@ class SequentialLinearModel(torch.nn.Module):

    def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
        super(SequentialLinearModel, self).__init__()
-        self.linears = torch.nn.ModuleList(
-            [torch.nn.Linear(hidden_dim, hidden_dim, bias=None) for i in range(nlayers)])
+        self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)])
        if empty_grad:
-            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=None)
+            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.empty_grad = empty_grad

@@ -153,8 +153,7 @@ def process_linear_layer(hidden_dim, input):
    torch_linear = nn.Linear(hidden_dim,
                             hidden_dim,
                             dtype=preferred_dtype(),
-                            device=get_accelerator().current_device(),
-                            bias=None)
+                            device=get_accelerator().current_device())
    torch_out = torch_linear(input)
    torch_loss = torch_out.sum()
    torch_loss.backward()
@@ -215,6 +214,9 @@ def testRowParallel(self, tp_size: int, tp_overlap_comm: bool):
        loss.backward()

        torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=1)[groups.get_tensor_model_parallel_rank()]
+        torch_bias_grad = torch_linear.bias.grad
+        assert torch.allclose(linear.bias.grad, torch_bias_grad.to(get_accelerator().current_device()), atol=1e-3)
+        # The gradient of the weight is not the same as the torch_linear.weight.grad
        assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
        assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-2)

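A note on what the new row-parallel assertions rely on: LinearAllreduce shards the weight along the input dimension, so each rank's weight gradient is only the matching chunk of the full gradient, while the bias is added after the all-reduce and its gradient is identical on every rank. Below is a minimal single-process sketch of that behaviour in plain PyTorch (no DeepSpeed; `tp`, `shards`, etc. are illustrative names, not test fixtures):

```python
import torch

torch.manual_seed(0)
hidden, tp = 8, 2
x = torch.randn(4, hidden)
full = torch.nn.Linear(hidden, hidden)

# Row parallel: shard the weight along the input dimension (dim=1).
shards = [torch.nn.Parameter(w.detach().clone()) for w in torch.chunk(full.weight, tp, dim=1)]
bias = torch.nn.Parameter(full.bias.detach().clone())
x_shards = torch.chunk(x, tp, dim=1)

# Each "rank" computes a partial matmul; summing the partials stands in for the
# all-reduce, and the bias is added once after the reduction.
partial = sum(xs @ ws.t() for xs, ws in zip(x_shards, shards))
(partial + bias).sum().backward()

# Reference: the unsharded layer.
full(x).sum().backward()

# The bias gradient is the full gradient; each weight shard only sees its own chunk.
assert torch.allclose(bias.grad, full.bias.grad, atol=1e-5)
for ws, ref in zip(shards, torch.chunk(full.weight.grad, tp, dim=1)):
    assert torch.allclose(ws.grad, ref, atol=1e-5)
```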
@@ -266,6 +268,10 @@ def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool):

        cur_device_out = torch.chunk(torch_out, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()]
        torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()]
+
+        torch_bias_grad = torch.chunk(torch_linear.bias.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()]
+        assert torch.allclose(linear.bias.grad, torch_bias_grad.to(get_accelerator().current_device()), atol=1e-3)
+
        assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3)
        assert torch.allclose(cur_device_out.to(get_accelerator().current_device()).contiguous(),
                              out.contiguous(),
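The column-parallel case is the mirror image: LinearLayer shards both weight and bias along the output dimension (dim=0), so the local bias gradient should equal the corresponding chunk of the full bias gradient. A rough single-process illustration in plain PyTorch (variable names are illustrative):

```python
import torch

torch.manual_seed(0)
hidden, tp, rank = 8, 2, 0          # pretend this process is TP rank 0
x = torch.randn(4, hidden)
full = torch.nn.Linear(hidden, hidden)

# Column parallel: take this rank's slice of weight and bias along dim=0.
w_local = torch.nn.Parameter(torch.chunk(full.weight, tp, dim=0)[rank].detach().clone())
b_local = torch.nn.Parameter(torch.chunk(full.bias, tp, dim=0)[rank].detach().clone())

# The local forward produces this rank's slice of the output columns.
(x @ w_local.t() + b_local).sum().backward()
full(x).sum().backward()

# Local gradients equal the corresponding chunks of the full gradients.
assert torch.allclose(w_local.grad, torch.chunk(full.weight.grad, tp, dim=0)[rank], atol=1e-6)
assert torch.allclose(b_local.grad, torch.chunk(full.bias.grad, tp, dim=0)[rank], atol=1e-6)
```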
@@ -307,23 +313,36 @@ def test(self, layer_type):
        model = SequentialLinearModel(hidden_dim=hidden_dim)
        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)

-        torch_linear = nn.Linear(hidden_dim, hidden_dim, dtype=preferred_dtype(), device="cpu", bias=None)
+        torch_linear = nn.Linear(hidden_dim, hidden_dim, dtype=preferred_dtype(), device="cpu")
        total_params = sum(p.numel() for p in torch_linear.parameters())
-
        tp_layer = None
        if layer_type == "linear":
-            tp_layer = LinearLayer(torch_linear, groups.get_tensor_model_parallel_group())
+            tp_layer = LinearLayer(deepcopy(torch_linear), groups.get_tensor_model_parallel_group())
        elif layer_type == "linearallreduce":
-            tp_layer = LinearAllreduce(torch_linear, groups.get_tensor_model_parallel_group())
+            tp_layer = LinearAllreduce(deepcopy(torch_linear), groups.get_tensor_model_parallel_group())
        else:
            raise ValueError(f"Invalid linear type: {config_dict['linear_type']}")

        tp_params = sum(p.numel() for p in tp_layer.parameters())

-        assert total_params // tp_size == tp_params
+        expected_tp_params = 0
+        # compute expected TP params:
+        # - column-parallel (LinearLayer): weight & bias both split => total // tp_size
+        # - row-parallel (LinearAllreduce): weight split, bias (1d tensors) replicated
+        if layer_type == "linearallreduce":
+            weight_params = torch_linear.weight.numel()
+            bias_params = torch_linear.bias.numel()
+            expected_tp_params = weight_params // tp_size + bias_params
+        else:
+            expected_tp_params = total_params // tp_size
+        assert expected_tp_params == tp_params, (
+            f"{layer_type}: expected {expected_tp_params} tp params, got {tp_params}")
+
        for name, param in tp_layer.named_parameters(recurse=False):
-            param.gather_params([param])
+            if is_model_parallel_parameter(param):
+                param.gather_params([param])

+        torch_linear = torch_linear.to(get_accelerator().current_device())
        is_same_weights = all(
            torch.equal(param1, param2) for param1, param2 in zip(tp_layer.parameters(), torch_linear.parameters()))

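The expected_tp_params computation reflects how each layer type shards an nn.Linear(hidden_dim, hidden_dim) with bias: column parallel splits every parameter across ranks, while row parallel splits only the 2-D weight and replicates the 1-D bias. A quick back-of-the-envelope check, using hidden_dim=8 and tp_size=2 purely for illustration:

```python
hidden_dim, tp_size = 8, 2                           # illustrative values only
weight, bias = hidden_dim * hidden_dim, hidden_dim   # 64 and 8 parameters
total = weight + bias                                # 72

column_parallel = total // tp_size                   # LinearLayer: (64 + 8) // 2 = 36
row_parallel = weight // tp_size + bias              # LinearAllreduce: 64 // 2 + 8 = 40

assert column_parallel == 36 and row_parallel == 40
```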
@@ -333,11 +352,12 @@ def test(self, layer_type):
        assert total_params == params1

        for name, param in tp_layer.named_parameters(recurse=False):
-            param._tp_partition([param])
+            if is_model_parallel_parameter(param):
+                param._tp_partition([param])

        tp_params2 = sum(p.numel() for p in tp_layer.parameters())

-        assert total_params // tp_size == tp_params2
+        assert expected_tp_params == tp_params2


def dummy_init_engine(config):
@@ -571,7 +591,7 @@ def test(self, tp_size: int, zero_stage: int):

        tp_norm = tp_optimizer._global_grad_norm

-        assert math.isclose(base_norm, tp_norm, abs_tol=1e-3)
+        assert math.isclose(base_norm, tp_norm, abs_tol=1e-3), f"base_norm: {base_norm}, tp_norm: {tp_norm}"
        tp_params_numel = sum(p.numel() for p in tp_model.parameters())
        base_params_numel = sum(p.numel() for p in base_model.parameters())
-        assert tp_params_numel < base_params_numel
+        assert tp_params_numel < base_params_numel, f"tp_params_numel: {tp_params_numel}, base_params_numel: {base_params_numel}"
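The tightened grad-norm assertion rests on the identity that the TP shards collectively hold the same gradient entries as the unsharded model, so the square root of the summed squared per-shard norms equals the baseline norm. A small sanity check of that identity in plain PyTorch (not DeepSpeed's actual norm computation, which additionally has to avoid double-counting replicated parameters):

```python
import torch

torch.manual_seed(0)
grad = torch.randn(8, 8)                      # stand-in for a full weight gradient
shards = torch.chunk(grad, 2, dim=0)          # what two TP ranks would each hold

# sqrt of the summed squared per-shard norms recovers the norm of the full gradient
global_norm = torch.sqrt(sum(s.norm() ** 2 for s in shards))
assert torch.allclose(global_norm, grad.norm(), atol=1e-6)
```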