23
23
24
24
def test_invalid_precision_with_deepspeed_precision ():
25
25
with pytest .raises (ValueError , match = "is not supported in DeepSpeed. `precision` must be one of" ):
26
- DeepSpeedPrecision (precision = 64 , amp_type = "native" )
27
-
28
-
29
- def test_deepspeed_precision_apex_not_installed (monkeypatch ):
30
- import lightning_lite .plugins .precision .deepspeed as deepspeed
31
-
32
- monkeypatch .setattr (deepspeed , "_APEX_AVAILABLE" , False )
33
- with pytest .raises (ImportError , match = "You have asked for Apex AMP but `apex` is not installed." ):
34
- DeepSpeedPrecision (precision = 16 , amp_type = "apex" )
35
-
36
-
37
- @mock .patch ("lightning_lite.plugins.precision.deepspeed._APEX_AVAILABLE" , return_value = True )
38
- def test_deepspeed_precision_apex_default_level (_ ):
39
- with pytest .deprecated_call (match = "apex AMP implementation has been deprecated" ):
40
- precision = DeepSpeedPrecision (precision = 16 , amp_type = "apex" )
41
- assert isinstance (precision , DeepSpeedPrecision )
42
- assert precision .amp_level == "O2"
26
+ DeepSpeedPrecision (precision = 64 )
43
27
44
28
45
29
def test_deepspeed_precision_backward ():
46
- precision = DeepSpeedPrecision (precision = 32 , amp_type = "native" )
30
+ precision = DeepSpeedPrecision (precision = 32 )
47
31
tensor = Mock ()
48
32
model = Mock ()
49
33
precision .backward (tensor , model , "positional-arg" , keyword = "arg" )
@@ -61,7 +45,7 @@ def test_deepspeed_engine_is_steppable(engine):
61
45
62
46
63
47
def test_deepspeed_precision_optimizer_step ():
64
- precision = DeepSpeedPrecision (precision = 32 , amp_type = "native" )
48
+ precision = DeepSpeedPrecision (precision = 32 )
65
49
optimizer = model = Mock ()
66
50
precision .optimizer_step (optimizer , lr_kwargs = dict ())
67
51
model .step .assert_called_once_with (lr_kwargs = dict ())
0 commit comments