
Commit 024d6dc

Fix tests
1 parent 78ec378 commit 024d6dc

File tree

7 files changed (+79 −60 lines)


src/pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 3 additions & 3 deletions

@@ -32,10 +32,10 @@
 @functools.lru_cache(maxsize=1)
 def _import_amp_without_deprecation() -> ModuleType:
     # hide the warning upstream in favor of our deprecation
-    with warnings.filterwarnings(action="ignore", message="apex.amp is deprecated", category=FutureWarning):
-        from apex import amp
+    warnings.filterwarnings(action="ignore", message="apex.amp is deprecated", category=FutureWarning)
+    from apex import amp

-        return amp
+    return amp


 # TODO: remove in v1.10.0
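
Note on the change above (my reading, not stated in the commit): `warnings.filterwarnings()` only appends a rule to the global filter list and returns `None`, so the old `with` form could not work as a context manager; the standard-library context manager for scoping filters is `warnings.catch_warnings()`. A minimal sketch of that scoped alternative, assuming `apex` is installed (the helper name is invented for illustration):

    import warnings
    from types import ModuleType


    def _import_amp_scoped() -> ModuleType:
        # catch_warnings() snapshots the filter list on entry and restores it on
        # exit, so the "ignore" rule does not leak into the rest of the process.
        with warnings.catch_warnings():
            warnings.filterwarnings(action="ignore", message="apex.amp is deprecated", category=FutureWarning)
            from apex import amp
        return amp

The commit instead applies the filter globally, which is simpler; since the function is wrapped in `functools.lru_cache(maxsize=1)`, it only runs once per process anyway.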

src/pytorch_lightning/plugins/precision/deepspeed.py

Lines changed: 2 additions & 2 deletions

@@ -66,8 +66,8 @@ def __init__(self, precision: Union[str, int], amp_type: str = "native", amp_lev
                 f"`{type(self).__name__}(amp_level={amp_level!r})` is only relevant when using NVIDIA/apex"
             )
         rank_zero_deprecation(
-            f"Passing `{type(self).__name__}(amp_type=...)` been deprecated in v1.9.0 and will be removed in"
-            " v1.10.0. This argument is no longer necessary."
+            f"Passing `{type(self).__name__}(amp_type={amp_type!r})` been deprecated in v1.9.0 and will be removed"
+            f" in v1.10.0. This argument is no longer necessary."
         )

         supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT)
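
The substantive difference in this hunk is that the deprecation message now interpolates the `amp_type` value the caller actually passed instead of a literal `...` placeholder, which is what lets the updated test in tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py match on `amp_type='native'`. A tiny self-contained check of how the two messages render (variable names invented for illustration):

    amp_type = "native"
    cls_name = "DeepSpeedPrecisionPlugin"

    old = f"Passing `{cls_name}(amp_type=...)` been deprecated in v1.9.0"
    new = f"Passing `{cls_name}(amp_type={amp_type!r})` been deprecated in v1.9.0"

    assert "amp_type=..." in old        # old message: literal placeholder
    assert "amp_type='native'" in new   # new message: the caller's value, via !r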

src/pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 4 deletions

@@ -48,7 +48,7 @@
     FullyShardedNativeMixedPrecisionPlugin,
     HPUPrecisionPlugin,
     IPUPrecisionPlugin,
-    NativeMixedPrecisionPlugin,
+    MixedPrecisionPlugin,
     PLUGIN_INPUT,
     PrecisionPlugin,
     ShardedNativeMixedPrecisionPlugin,
@@ -717,7 +717,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:

         if self._precision_flag in (16, "bf16"):
             rank_zero_info(
-                f"Using 16bit {self._amp_type_flag.value} Automatic Mixed Precision (AMP)"  # type: ignore
+                f"Using 16bit {self._amp_type_flag} Automatic Mixed Precision (AMP)"  # type: ignore
                 if self._precision_flag == 16
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
@@ -731,7 +731,7 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
                 return FullyShardedNativeNativeMixedPrecisionPlugin(self._precision_flag, device)
             if isinstance(self.strategy, DDPFullyShardedStrategy):
                 return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
-            return NativeMixedPrecisionPlugin(self._precision_flag, device)
+            return MixedPrecisionPlugin(self._precision_flag, device)

         if self._amp_type_flag == "apex":
             self._amp_level_flag = self._amp_level_flag or "O2"
@@ -771,7 +771,7 @@ def _validate_precision_choice(self) -> None:
             )
         if self._precision_flag == "bf16" and self._amp_type_flag != "native":
             raise MisconfigurationException(
-                f"You passed `Trainer(amp_type={self._amp_type_flag.value!r}, precision='bf16')` but "  # type: ignore
+                f"You passed `Trainer(amp_type={self._amp_type_flag!r}, precision='bf16')` but "  # type: ignore
                 "it's not supported. Try using `amp_type='native'` instead."
             )
         if self._precision_flag in (16, "bf16") and self._amp_type_flag == "apex":
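
Two independent things happen in this file (my reading of the diff, not stated in the commit message): the `NativeMixedPrecisionPlugin` import and return value are renamed to `MixedPrecisionPlugin`, and the `.value` accessor is dropped when formatting `_amp_type_flag`, which suggests the flag is now a plain string rather than an enum member. A small illustration of why that matters for the messages (the `AMPType` class below is a stand-in written for this note, not the project's actual enum):

    from enum import Enum


    class AMPType(Enum):
        NATIVE = "native"


    flag_enum = AMPType.NATIVE
    flag_str = "native"

    # An enum member needs .value to render as the bare string:
    assert f"{flag_enum.value}" == "native"
    assert f"{flag_enum}" == "AMPType.NATIVE"
    # A plain string interpolates directly, and !r adds the quotes seen in the error message:
    assert f"{flag_str}" == "native"
    assert f"{flag_str!r}" == "'native'"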

tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py

Lines changed: 3 additions & 1 deletion

@@ -20,7 +20,9 @@


 def test_invalid_precision_with_deepspeed_precision():
-    with pytest.raises(ValueError, match="is not supported. `precision` must be one of"):
+    with pytest.deprecated_call(match=r"amp_type='native'\)` been deprecated in v1.9.0"), pytest.raises(
+        ValueError, match="is not supported. `precision` must be one of"
+    ):
         DeepSpeedPrecisionPlugin(precision=64, amp_type="native")

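The pattern introduced here (and again in test_amp_plugins.py below) stacks two pytest context managers over the same call: `pytest.deprecated_call` asserts that a DeprecationWarning or PendingDeprecationWarning matching the regex was emitted, while the inner `pytest.raises` still expects and swallows the exception. A self-contained sketch of the same idea with an invented function, not tied to Lightning:

    import warnings

    import pytest


    def emit_and_fail(precision):
        # warn first, then fail, like the plugin constructor under test
        warnings.warn("`amp_type` has been deprecated", DeprecationWarning)
        raise ValueError(f"Precision {precision} is not supported.")


    def test_emit_and_fail():
        with pytest.deprecated_call(match="been deprecated"), pytest.raises(
            ValueError, match="is not supported"
        ):
            emit_and_fail(64)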

tests/tests_pytorch/plugins/precision/test_native_amp.py

Lines changed: 3 additions & 3 deletions

@@ -16,14 +16,14 @@
 import pytest
 from torch.optim import Optimizer

-from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
+from pytorch_lightning.plugins import MixedPrecisionPlugin
 from pytorch_lightning.utilities import GradClipAlgorithmType


 def test_clip_gradients():
     """Test that `.clip_gradients()` is a no-op when clipping is disabled."""
     optimizer = Mock(spec=Optimizer)
-    precision = NativeMixedPrecisionPlugin(precision=16, device="cuda:0", scaler=Mock())
+    precision = MixedPrecisionPlugin(precision=16, device="cuda:0", scaler=Mock())
     precision.clip_grad_by_value = Mock()
     precision.clip_grad_by_norm = Mock()
     precision.clip_gradients(optimizer)
@@ -47,7 +47,7 @@ def test_optimizer_amp_scaling_support_in_step_method():
     gradient clipping (example: fused Adam)."""

     optimizer = Mock(_step_supports_amp_scaling=True)
-    precision = NativeMixedPrecisionPlugin(precision=16, device="cuda:0", scaler=Mock())
+    precision = MixedPrecisionPlugin(precision=16, device="cuda:0", scaler=Mock())

     with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
         precision.clip_gradients(optimizer, clip_val=1.0)
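
These two tests only change the plugin class name; the mocking pattern around them is unchanged. For readers unfamiliar with it, `Mock(spec=Optimizer)` builds a mock that passes `isinstance` checks against `Optimizer` and rejects attributes the real class does not define. A quick standalone illustration (assumes torch is installed):

    from unittest.mock import Mock

    from torch.optim import Optimizer

    optimizer = Mock(spec=Optimizer)
    assert isinstance(optimizer, Optimizer)   # spec'd mocks satisfy isinstance
    try:
        optimizer.not_a_real_attribute        # attributes outside the spec raise
    except AttributeError:
        pass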

tests/tests_pytorch/plugins/test_amp_plugins.py

Lines changed: 13 additions & 11 deletions

@@ -20,13 +20,13 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
+from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, MixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests_pytorch.conftest import mock_cuda_count
 from tests_pytorch.helpers.runif import RunIf


-class MyNativeAMP(NativeMixedPrecisionPlugin):
+class MyNativeAMP(MixedPrecisionPlugin):
     pass


@@ -52,7 +52,7 @@ class MyApexPlugin(ApexMixedPrecisionPlugin):
 @pytest.mark.parametrize(
     "amp,custom_plugin,plugin_cls",
     [
-        ("native", False, NativeMixedPrecisionPlugin),
+        ("native", False, MixedPrecisionPlugin),
         ("native", True, MyNativeAMP),
         pytest.param("apex", False, ApexMixedPrecisionPlugin, marks=RunIf(amp_apex=True)),
         pytest.param("apex", True, MyApexPlugin, marks=RunIf(amp_apex=True)),
@@ -189,9 +189,7 @@ def configure_optimizers(self):
                 torch.optim.SGD(self.layer2.parameters(), lr=0.1),
             ]

-        trainer = Trainer(
-            default_root_dir=tmpdir, accelerator="gpu", devices=1, fast_dev_run=1, amp_backend="native", precision=16
-        )
+        trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=1, fast_dev_run=1, precision=16)
         model = CustomBoringModel()
         trainer.fit(model)

@@ -246,7 +244,7 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):

 def test_cpu_amp_precision_context_manager(tmpdir):
     """Test to ensure that the context manager correctly is set to CPU + bfloat16."""
-    plugin = NativeMixedPrecisionPlugin("bf16", "cpu")
+    plugin = MixedPrecisionPlugin("bf16", "cpu")
     assert plugin.device == "cpu"
     assert plugin.scaler is None
     context_manager = plugin.autocast_context_manager()
@@ -256,16 +254,20 @@ def test_cpu_amp_precision_context_manager(tmpdir):


 def test_precision_selection_raises(monkeypatch):
-    with pytest.raises(
+    with pytest.deprecated_call(match=r"amp_backend='apex'\)` argument is deprecated"), pytest.raises(
         MisconfigurationException, match=r"precision=16, amp_type='apex'\)` but apex AMP not supported on CPU"
     ):
         Trainer(amp_backend="apex", precision=16)

-    with pytest.raises(MisconfigurationException, match=r"amp_type='apex', precision='bf16'\)` but it's not supported"):
+    with pytest.deprecated_call(match=r"amp_backend='apex'\)` argument is deprecated"), pytest.raises(
+        MisconfigurationException, match=r"amp_type='apex', precision='bf16'\)` but it's not supported"
+    ):
         Trainer(amp_backend="apex", precision="bf16")

     mock_cuda_count(monkeypatch, 1)
-    with pytest.raises(MisconfigurationException, match="Sharded plugins are not supported with apex"):
+    with pytest.deprecated_call(match=r"amp_backend='apex'\)` argument is deprecated"), pytest.raises(
+        MisconfigurationException, match="Sharded plugins are not supported with apex"
+    ):
         with mock.patch("lightning_lite.accelerators.cuda.is_cuda_available", return_value=True):
             Trainer(amp_backend="apex", precision=16, accelerator="gpu", devices=1, strategy="ddp_fully_sharded")

@@ -274,5 +276,5 @@ def test_precision_selection_raises(monkeypatch):
     monkeypatch.setattr(apex, "_APEX_AVAILABLE", False)
     with mock.patch("lightning_lite.accelerators.cuda.is_cuda_available", return_value=True), pytest.raises(
         MisconfigurationException, match="asked for Apex AMP but `apex` is not installed"
-    ), pytest.deprecated_call(match="apex AMP implementation has been deprecated"):
+    ), pytest.deprecated_call(match=r"amp_backend='apex'\)` argument is deprecated"):
         Trainer(amp_backend="apex", precision=16, accelerator="gpu", devices=1)

tests/tests_pytorch/trainer/optimization/test_manual_optimization.py

Lines changed: 51 additions & 36 deletions

@@ -65,19 +65,51 @@ def configure_optimizers(self):


 @pytest.mark.parametrize(
-    "kwargs",
-    [
-        {},
-        pytest.param(
-            {"accelerator": "gpu", "devices": 1, "precision": 16, "amp_backend": "native"}, marks=RunIf(min_cuda_gpus=1)
-        ),
-        pytest.param(
-            {"accelerator": "gpu", "devices": 1, "precision": 16, "amp_backend": "apex"},
-            marks=RunIf(min_cuda_gpus=1, amp_apex=True),
-        ),
-    ],
+    "kwargs", [{}, pytest.param({"accelerator": "gpu", "devices": 1, "precision": 16}, marks=RunIf(min_cuda_gpus=1))]
 )
 def test_multiple_optimizers_manual_no_return(tmpdir, kwargs):
+    class TestModel(ManualOptModel):
+        def training_step(self, batch, batch_idx):
+            # avoid returning a value
+            super().training_step(batch, batch_idx)
+
+        def training_epoch_end(self, outputs):
+            # outputs is empty as training_step does not return
+            # and it is not automatic optimization
+            assert not outputs
+
+    model = TestModel()
+    model.val_dataloader = None
+
+    limit_train_batches = 2
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        enable_model_summary=False,
+        **kwargs,
+    )
+
+    if kwargs.get("precision") == 16:
+        # mock the scaler instead of the optimizer step because it can be skipped with NaNs
+        scaler_step_patch = mock.patch.object(
+            trainer.precision_plugin.scaler, "step", wraps=trainer.precision_plugin.scaler.step
+        )
+        scaler_step = scaler_step_patch.start()
+
+    with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock:
+        trainer.fit(model)
+    assert bwd_mock.call_count == limit_train_batches * 3
+
+    if kwargs.get("precision") == 16:
+        scaler_step_patch.stop()
+        assert scaler_step.call_count == len(model.optimizers()) * limit_train_batches
+
+
+@RunIf(min_cuda_gpus=1, amp_apex=True)
+def test_multiple_optimizers_manual_no_return_apex(tmpdir):
     apex_optimizer_patches = []
     apex_optimizer_steps = []

@@ -92,8 +124,6 @@ def training_epoch_end(self, outputs):
             assert not outputs

         def on_train_start(self):
-            if kwargs.get("amp_backend") != "apex":
-                return
             # extremely ugly. APEX patches all the native torch optimizers on `_initialize` which we call on
             # `ApexMixedPrecisionPlugin.dispatch`. Additionally, their replacement `new_step` functions are locally
             # defined so can't even patch those, thus we need to create the mock after APEX has been initialized
@@ -106,19 +136,15 @@ def on_train_start(self):
             apex_optimizer_steps.append(patch.start())

         def on_train_end(self):
-            if kwargs.get("amp_backend") == "apex":
-                for p in apex_optimizer_patches:
-                    p.stop()
+            for p in apex_optimizer_patches:
+                p.stop()

     model = TestModel()
     model.val_dataloader = None

     limit_train_batches = 2
-    plugins = []
-    if kwargs.get("amp_backend") == "apex":
-        with pytest.deprecated_call(match="apex AMP implementation has been deprecated"):
-            apex_plugin = ApexMixedPrecisionPlugin(amp_level="O2")
-            plugins.append(apex_plugin)
+    with pytest.deprecated_call(match="apex AMP implementation has been deprecated"):
+        plugins = [ApexMixedPrecisionPlugin(amp_level="O2")]

     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -128,25 +154,16 @@ def on_train_end(self):
         log_every_n_steps=1,
         enable_model_summary=False,
         plugins=plugins,
-        **kwargs,
+        accelerator="gpu",
+        devices=1,
+        precision=16,
     )

-    if kwargs.get("amp_backend") == "native":
-        # mock the scaler instead of the optimizer step because it can be skipped with NaNs
-        scaler_step_patch = mock.patch.object(
-            trainer.precision_plugin.scaler, "step", wraps=trainer.precision_plugin.scaler.step
-        )
-        scaler_step = scaler_step_patch.start()
-
     with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock:
         trainer.fit(model)
     assert bwd_mock.call_count == limit_train_batches * 3

-    if kwargs.get("amp_backend") == "native":
-        scaler_step_patch.stop()
-        assert scaler_step.call_count == len(model.optimizers()) * limit_train_batches
-    if kwargs.get("amp_backend") == "apex":
-        assert [s.call_count for s in apex_optimizer_steps] == [len(model.optimizers())] * limit_train_batches
+    assert [s.call_count for s in apex_optimizer_steps] == [len(model.optimizers())] * limit_train_batches


 def test_multiple_optimizers_manual_return(tmpdir):
@@ -396,7 +413,6 @@ def on_train_epoch_end(self, *_, **__):
         limit_test_batches=0,
         limit_val_batches=0,
         precision=16,
-        amp_backend="native",
         accelerator="gpu",
         devices=1,
     )
@@ -480,7 +496,6 @@ def log_grad_norm(self, grad_norm_dict):
         log_every_n_steps=1,
         enable_model_summary=False,
         precision=16,
-        amp_backend="native",
         accelerator="gpu",
         devices=1,
         track_grad_norm=2,
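
A note on the mocking pattern that survives into both of the split tests above (my gloss, not part of the commit): `mock.patch.object(obj, "attr", wraps=obj.attr)` swaps the attribute for a MagicMock that still delegates to the real implementation, so the test can count `scaler.step` and `Strategy.backward` calls without altering behaviour. A minimal standalone illustration with a toy class invented for this note:

    from unittest import mock


    class Scaler:
        def step(self, optimizer):
            return f"stepped {optimizer}"


    scaler = Scaler()
    patcher = mock.patch.object(scaler, "step", wraps=scaler.step)
    step_mock = patcher.start()
    try:
        assert scaler.step("adam") == "stepped adam"  # original behaviour preserved
        assert step_mock.call_count == 1              # calls are still recorded
    finally:
        patcher.stop()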
