Commit d7ed534

inkcherry authored and sfc-gh-truwase committed

Fix ci hang in torch2.7 & improve ut (deepspeedai#7321)

Fix CI hang. Improve the UT.

---------

Signed-off-by: inkcherry <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 5841b54 commit d7ed534
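The UT improvement targets the `should_assert_with_msg` helper: before this commit the context manager only inspected an `AssertionError` that was actually raised, so a test passed vacuously when the expected assertion never fired. The sketch below is an assumed reconstruction of the helper after the fix (the diff only shows fragments of the surrounding body, and the structure as a `contextlib.contextmanager` is an assumption); only the trailing `else:` branch is what this commit adds.

# Minimal sketch, not the verbatim repository code; structure assumed,
# only the final `else:` branch mirrors the lines added by this commit.
from contextlib import contextmanager


@contextmanager
def should_assert_with_msg(expected_message):
    try:
        yield
    except AssertionError as e:
        if expected_message in str(e):
            pass          # the expected assertion fired: swallow it
        else:
            raise e       # a different assertion: surface it to fail the test
    else:
        # New in this commit: fail loudly when the guarded block raised nothing,
        # instead of letting the test pass silently.
        raise AssertionError(f"Expected AssertionError with message '{expected_message}' "
                             "but no exception was raised.")

A test would use it as `with should_assert_with_msg("<expected text>"): run_step_expected_to_assert()`; with the added branch, a step that never asserts now fails the test instead of passing unnoticed.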


tests/unit/model_parallelism/test_autotp_training.py

Lines changed: 11 additions & 7 deletions
--- a/tests/unit/model_parallelism/test_autotp_training.py
+++ b/tests/unit/model_parallelism/test_autotp_training.py
@@ -58,6 +58,9 @@ def should_assert_with_msg(expected_message):
             pass
         else:
             raise e
+    else:
+        raise AssertionError(f"Expected AssertionError with message '{expected_message}' "
+                             "but no exception was raised.")
 
 
 @pytest.mark.parametrize("tp_size", [2, 4])
@@ -67,7 +70,6 @@ class TestTpParallelStates(DistributedTest):
     def test(self, tp_size: int):
         skip_on_device()
         set_autotp_mode(training=True)
-
         dp_size = 4 / tp_size
         hidden_dim = 128
         config_dict = {
@@ -88,7 +90,7 @@ def test(self, tp_size: int):
 @pytest.mark.parametrize("tp_size", [2, 4])
 class TestTpDataloaderCorrectness(DistributedTest):
     world_size = 4
-    reuse_dist_env = True
+    reuse_dist_env = False
 
     def test(self, tp_size: int):
         skip_on_device()
@@ -117,6 +119,8 @@ def test(self, tp_size: int):
 
         model = SimpleModel(hidden_dim=hidden_dim)
         model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        torch.manual_seed(42)
+
         data_loader = random_dataloader(model=model,
                                         total_samples=3,
                                         hidden_dim=hidden_dim,
@@ -165,7 +169,7 @@ def process_linear_layer(hidden_dim, input):
 @pytest.mark.parametrize("tp_overlap_comm", [True, False])
 class TestTpLayerFwdBwd(DistributedTest):
     world_size = 4
-    reuse_dist_env = True
+    reuse_dist_env = False
 
     def testRowParallel(self, tp_size: int, tp_overlap_comm: bool):
         skip_on_device()
@@ -278,10 +282,10 @@ def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool):
                            atol=1e-2)
 
 
-@pytest.mark.sequential
+# @pytest.mark.sequential
 class TestParamsGather(DistributedTest):
     world_size = 4
-    reuse_dist_env = True
+    reuse_dist_env = False
 
     @pytest.mark.parametrize("layer_type", ["linear", "linearallreduce"])
     def test(self, layer_type):
@@ -388,7 +392,7 @@ def prepare_tp_model(hidden_dim, nlayers, linear_indices, allreduce_indices, gro
 class TestSave(DistributedTest):
 
     world_size = 4
-    reuse_dist_env = True
+    reuse_dist_env = False
 
     def test_save_original_weight(self, tp_size: int, zero_stage: int):
         skip_on_device()
@@ -520,7 +524,7 @@ def test_ckpt_save(self, tmpdir, tp_size: int, zero_stage: int):
 class TestTpGradNorm(DistributedTest):
 
     world_size = 4
-    reuse_dist_env = True
+    reuse_dist_env = False
 
     def test(self, tp_size: int, zero_stage: int):
         skip_on_device()
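The other UT tweak seeds PyTorch's RNG right before `random_dataloader` builds its samples, presumably so the dataloader-correctness check compares ranks and runs against identical random inputs. A tiny illustration of the effect (hypothetical tensor shapes, not the test's real data):

# Illustrative only: a fixed seed makes the "random" batch identical on every
# invocation, so comparisons in the test are stable across runs.
import torch

torch.manual_seed(42)
batch_a = torch.randn(3, 128)

torch.manual_seed(42)
batch_b = torch.randn(3, 128)

assert torch.equal(batch_a, batch_b)  # same seed -> same samples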
