@@ -58,6 +58,9 @@ def should_assert_with_msg(expected_message):
58
58
pass
59
59
else :
60
60
raise e
61
+ else :
62
+ raise AssertionError (f"Expected AssertionError with message '{ expected_message } ' "
63
+ "but no exception was raised." )
61
64
62
65
63
66
@pytest .mark .parametrize ("tp_size" , [2 , 4 ])
@@ -67,7 +70,6 @@ class TestTpParallelStates(DistributedTest):
67
70
def test (self , tp_size : int ):
68
71
skip_on_device ()
69
72
set_autotp_mode (training = True )
70
-
71
73
dp_size = 4 / tp_size
72
74
hidden_dim = 128
73
75
config_dict = {
@@ -88,7 +90,7 @@ def test(self, tp_size: int):
88
90
@pytest .mark .parametrize ("tp_size" , [2 , 4 ])
89
91
class TestTpDataloaderCorrectness (DistributedTest ):
90
92
world_size = 4
91
- reuse_dist_env = True
93
+ reuse_dist_env = False
92
94
93
95
def test (self , tp_size : int ):
94
96
skip_on_device ()
@@ -117,6 +119,8 @@ def test(self, tp_size: int):
117
119
118
120
model = SimpleModel (hidden_dim = hidden_dim )
119
121
model , _ , _ , _ = deepspeed .initialize (model = model , model_parameters = model .parameters (), config = config_dict )
122
+ torch .manual_seed (42 )
123
+
120
124
data_loader = random_dataloader (model = model ,
121
125
total_samples = 3 ,
122
126
hidden_dim = hidden_dim ,
@@ -165,7 +169,7 @@ def process_linear_layer(hidden_dim, input):
165
169
@pytest .mark .parametrize ("tp_overlap_comm" , [True , False ])
166
170
class TestTpLayerFwdBwd (DistributedTest ):
167
171
world_size = 4
168
- reuse_dist_env = True
172
+ reuse_dist_env = False
169
173
170
174
def testRowParallel (self , tp_size : int , tp_overlap_comm : bool ):
171
175
skip_on_device ()
@@ -278,10 +282,10 @@ def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool):
278
282
atol = 1e-2 )
279
283
280
284
281
- @pytest .mark .sequential
285
+ # @pytest.mark.sequential
282
286
class TestParamsGather (DistributedTest ):
283
287
world_size = 4
284
- reuse_dist_env = True
288
+ reuse_dist_env = False
285
289
286
290
@pytest .mark .parametrize ("layer_type" , ["linear" , "linearallreduce" ])
287
291
def test (self , layer_type ):
@@ -388,7 +392,7 @@ def prepare_tp_model(hidden_dim, nlayers, linear_indices, allreduce_indices, gro
388
392
class TestSave (DistributedTest ):
389
393
390
394
world_size = 4
391
- reuse_dist_env = True
395
+ reuse_dist_env = False
392
396
393
397
def test_save_original_weight (self , tp_size : int , zero_stage : int ):
394
398
skip_on_device ()
@@ -520,7 +524,7 @@ def test_ckpt_save(self, tmpdir, tp_size: int, zero_stage: int):
520
524
class TestTpGradNorm (DistributedTest ):
521
525
522
526
world_size = 4
523
- reuse_dist_env = True
527
+ reuse_dist_env = False
524
528
525
529
def test (self , tp_size : int , zero_stage : int ):
526
530
skip_on_device ()
0 commit comments