diff --git a/tests/unit/model_parallelism/test_autotp_training.py b/tests/unit/model_parallelism/test_autotp_training.py
index 13ead0cec38d..cba285b06423 100644
--- a/tests/unit/model_parallelism/test_autotp_training.py
+++ b/tests/unit/model_parallelism/test_autotp_training.py
@@ -58,6 +58,9 @@ def should_assert_with_msg(expected_message):
             pass
         else:
             raise e
+    else:
+        raise AssertionError(f"Expected AssertionError with message '{expected_message}' "
+                             "but no exception was raised.")
 
 
 @pytest.mark.parametrize("tp_size", [2, 4])
@@ -67,7 +70,6 @@ class TestTpParallelStates(DistributedTest):
     def test(self, tp_size: int):
         skip_on_device()
         set_autotp_mode(training=True)
-        dp_size = 4 / tp_size
 
         hidden_dim = 128
         config_dict = {
@@ -88,7 +90,7 @@ def test(self, tp_size: int):
 @pytest.mark.parametrize("tp_size", [2, 4])
 class TestTpDataloaderCorrectness(DistributedTest):
     world_size = 4
-    reuse_dist_env = True
+    reuse_dist_env = False
 
     def test(self, tp_size: int):
         skip_on_device()
@@ -117,6 +119,8 @@ def test(self, tp_size: int):
 
         model = SimpleModel(hidden_dim=hidden_dim)
         model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        torch.manual_seed(42)
+
         data_loader = random_dataloader(model=model,
                                         total_samples=3,
                                         hidden_dim=hidden_dim,
@@ -165,7 +169,7 @@ def process_linear_layer(hidden_dim, input):
 @pytest.mark.parametrize("tp_overlap_comm", [True, False])
 class TestTpLayerFwdBwd(DistributedTest):
     world_size = 4
-    reuse_dist_env = True
+    reuse_dist_env = False
 
     def testRowParallel(self, tp_size: int, tp_overlap_comm: bool):
         skip_on_device()
@@ -278,10 +282,10 @@ def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool):
                            atol=1e-2)
 
 
-@pytest.mark.sequential
+# @pytest.mark.sequential
 class TestParamsGather(DistributedTest):
     world_size = 4
-    reuse_dist_env = True
+    reuse_dist_env = False
 
     @pytest.mark.parametrize("layer_type", ["linear", "linearallreduce"])
     def test(self, layer_type):
@@ -388,7 +392,7 @@ def prepare_tp_model(hidden_dim, nlayers, linear_indices, allreduce_indices, gro
 
 class TestSave(DistributedTest):
     world_size = 4
-    reuse_dist_env = True
+    reuse_dist_env = False
 
     def test_save_original_weight(self, tp_size: int, zero_stage: int):
         skip_on_device()
@@ -520,7 +524,7 @@ def test_ckpt_save(self, tmpdir, tp_size: int, zero_stage: int):
 
 class TestTpGradNorm(DistributedTest):
     world_size = 4
-    reuse_dist_env = True
+    reuse_dist_env = False
 
     def test(self, tp_size: int, zero_stage: int):
         skip_on_device()
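
For context, a minimal sketch of what `should_assert_with_msg` looks like after the first hunk. The `try`/`except` body above the added `else` branch is reconstructed from the hunk's context lines, so its exact shape may differ from the file; the point is the `try`/`except`/`else` pattern, where the `else` clause fires only when the `with` body raised nothing:

```python
from contextlib import contextmanager


@contextmanager
def should_assert_with_msg(expected_message):
    # Expect the `with` body to raise an AssertionError containing
    # `expected_message`.
    try:
        yield
    except AssertionError as e:
        if expected_message in str(e):
            pass  # the expected failure occurred
        else:
            raise e  # an AssertionError, but not the one we expected
    else:
        # Added by this patch: previously the helper returned silently
        # when the body raised nothing, so a broken negative test could
        # pass without exercising the assertion at all.
        raise AssertionError(f"Expected AssertionError with message '{expected_message}' "
                             "but no exception was raised.")


# The helper in action: this passes, because the body raises the error
# we expect (the messages here are illustrative, not from the test suite).
with should_assert_with_msg("tp_size must divide"):
    assert False, "tp_size must divide the number of attention heads"

# A body that raises nothing now trips the new branch instead of
# passing silently:
try:
    with should_assert_with_msg("never raised"):
        pass
except AssertionError as e:
    print(e)  # "Expected AssertionError with message 'never raised' but ..."
```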