@@ -241,7 +241,7 @@ def _make_block(self):
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_multigpu_stage_3(tmpdir):
+def test_deepspeed_multigpu_stage_3():
     """Test to ensure ZeRO Stage 3 works with a parallel model."""
     fabric = ModelParallelClassification(
         strategy=DeepSpeedStrategy(stage=3),
@@ -255,7 +255,7 @@ def test_deepspeed_multigpu_stage_3(tmpdir):
 @RunIf(deepspeed=True)
 @mock.patch("deepspeed.init_distributed", autospec=True)
 @pytest.mark.parametrize("platform", ["Linux", "Windows"])
-def test_deepspeed_env_variables_on_platforms(deepspeed_dist_mock, tmpdir, platform):
+def test_deepspeed_env_variables_on_platforms(deepspeed_dist_mock, platform):
     """Test to ensure that we set up distributed communication correctly.
 
     When using Windows, ranks environment variables should not be set, and DeepSpeed should handle this.
@@ -279,7 +279,7 @@ def test_deepspeed_env_variables_on_platforms(deepspeed_dist_mock, tmpdir, platf
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_specific_gpu_device_index(tmpdir):
+def test_deepspeed_specific_gpu_device_index():
     """Test that the DeepSpeed strategy can run on specific device indices."""
 
     class RunFabric(BoringFabric):
@@ -295,7 +295,7 @@ def step(self, model, batch):
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
-def test_deepspeed_with_bfloat16_precision(tmpdir):
+def test_deepspeed_with_bfloat16_precision():
     """Test that the DeepSpeed strategy works with bfloat16 precision."""
 
     class Model(nn.Module):
@@ -322,3 +322,88 @@ def step(self, model, batch):
     assert fabric._strategy.precision.precision == "bf16"
     assert fabric._strategy.config["zero_optimization"]["stage"] == 3
     fabric.run()
+
+
+def _assert_saved_model_is_equal(fabric, model, checkpoint_path):
+    """Convert the saved checkpoint to a single file with the model weights consolidated to easily verify the full
+    weights in float32 precision."""
+    from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+
+    assert isinstance(fabric.strategy, DeepSpeedStrategy)
+
+    # carry out the check only on rank 0
+    if fabric.is_global_zero:
+        if fabric.strategy.config["zero_optimization"]["stage"] in (2, 3):
+            single_ckpt_path = checkpoint_path / "single_model.pt"
+            # the tag is hardcoded in DeepSpeedStrategy
+            convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path, tag="checkpoint")
+            state_dict = torch.load(single_ckpt_path)
+        else:
+            # 'checkpoint' is the tag, hardcoded in DeepSpeedStrategy
+            single_ckpt_path = checkpoint_path / "checkpoint" / "mp_rank_00_model_states.pt"
+            state_dict = torch.load(single_ckpt_path)["module"]
+
+        model = model.cpu()
+
+        # assert model parameters are identical after loading
+        for orig_param, saved_model_param in zip(model.parameters(), state_dict.values()):
+            # perform the equality check in the same precision
+            saved_model_param = saved_model_param.cpu().to(orig_param.dtype)
+            assert torch.equal(orig_param, saved_model_param)
+
+    fabric.barrier()
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
+@pytest.mark.parametrize("stage", [1, 2, 3])
+def test_deepspeed_save_load_checkpoint_zero_3(stage, tmp_path):
+    """Test that DeepSpeed stage 1, 2, and 3 model checkpoints can be saved and loaded successfully."""
+    from deepspeed import DeepSpeedEngine
+
+    fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16")
+    fabric.launch()
+
+    checkpoint_path = fabric.broadcast(tmp_path / "deepspeed-checkpoint")
+
+    with fabric.sharded_model():
+        model = BoringModel()
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
+    model, optimizer = fabric.setup(model, optimizer)
+    assert isinstance(model._forward_module, DeepSpeedEngine)
+
+    # TODO(fabric): The dtype on the model is not correct, should be torch.bfloat16
+    assert model.dtype == torch.float32
+    assert next(model.parameters()).dtype == torch.bfloat16
+
+    # dummy training step
+    output = model(torch.randn(1, 32).to(fabric.device))
+    loss = output.sum()
+    fabric.backward(loss)
+    optimizer.step()
+    optimizer.zero_grad()
+
+    state = {"model": model, "optimizer": optimizer, "steps": 1}
+    fabric.save(checkpoint_path, state)
+
+    fabric.barrier()
+
+    # re-init all objects and resume
+    fabric = Fabric(accelerator="cuda", devices=2, strategy=DeepSpeedStrategy(stage=stage), precision="bf16")
+    fabric.launch()
+    with fabric.sharded_model():
+        model = BoringModel()
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
+    model, optimizer = fabric.setup(model, optimizer)
+    state = {"model": model, "optimizer": optimizer, "steps": 0}
+
+    metadata = fabric.load(checkpoint_path, state)
+    fabric.barrier()
+
+    # check user data in state reloaded
+    assert state["steps"] == 1
+    # the remainder of the deepspeed checkpoint contains metadata
+    assert "ds_version" in metadata
+
+    _assert_saved_model_is_equal(fabric, model, checkpoint_path)
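
Note: the consolidation step that `_assert_saved_model_is_equal` relies on can also be run on its own. The sketch below is a minimal standalone example, assuming a ZeRO stage 2/3 checkpoint directory written with the tag "checkpoint" (as the helper above expects); the directory and output file names here are placeholders, not part of the change.

    import torch
    from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

    checkpoint_dir = "deepspeed-checkpoint"      # directory produced by fabric.save(...), placeholder path
    output_file = "consolidated_fp32_model.pt"   # single-file float32 state dict to create, placeholder path

    # Merge the sharded ZeRO parameter/optimizer files into one float32 state dict;
    # the tag must match the tag used when the checkpoint was written.
    convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag="checkpoint")

    state_dict = torch.load(output_file)
    print(sorted(state_dict.keys()))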