
Commit e24f958

Fixes
1 parent 645db8c commit e24f958

10 files changed: +22 −37 lines changed


src/pytorch_lightning/lite/lite.py

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@
 from pytorch_lightning.accelerators import Accelerator as PLAccelerator
 from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin as PLDeepSpeedPrecisionPlugin
 from pytorch_lightning.plugins import DoublePrecisionPlugin as PLDoublePrecisionPlugin
-from pytorch_lightning.plugins import NativeMixedPrecisionPlugin as PLNativeMixedPrecisionPlugin
+from pytorch_lightning.plugins import MixedPrecisionPlugin as PLMixedPrecisionPlugin
 from pytorch_lightning.plugins import PrecisionPlugin as PLPrecisionPlugin
 from pytorch_lightning.plugins import TPUBf16PrecisionPlugin as PLTPUBf16PrecisionPlugin
 from pytorch_lightning.plugins import TPUPrecisionPlugin as PLTPUPrecisionPlugin
@@ -284,7 +284,7 @@ def _to_lite_precision(plugin: Optional[PLPrecisionPlugin]) -> LitePrecision:
     if type(plugin) is PLPrecisionPlugin:
         return LitePrecision()

-    if type(plugin) is PLNativeMixedPrecisionPlugin:
+    if type(plugin) is PLMixedPrecisionPlugin:
         return LiteMixedPrecision(
             precision=plugin.precision, device=plugin.device, scaler=plugin.scaler  # type: ignore[arg-type]
         )
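For orientation, a minimal sketch of the mapping this hunk touches: constructing the renamed Trainer-side plugin and noting what `_to_lite_precision` is expected to return. The Lite-side class name `LiteMixedPrecision` is taken from this diff; everything else in the comments is an assumption, not verified behavior.

# Hedged sketch: the renamed Trainer-side plugin feeding the Lite conversion above.
from pytorch_lightning.plugins import MixedPrecisionPlugin

pl_plugin = MixedPrecisionPlugin(precision=16, device="cuda")
# _to_lite_precision(pl_plugin) is now expected to return a LiteMixedPrecision
# built from pl_plugin.precision, pl_plugin.device, and pl_plugin.scaler.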

src/pytorch_lightning/plugins/precision/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,7 @@
 from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin
 from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin
-from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.native_amp import MixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin
@@ -35,6 +35,7 @@
     "HPUPrecisionPlugin",
     "IPUPrecisionPlugin",
     "NativeMixedPrecisionPlugin",
+    "MixedPrecisionPlugin",
     "PrecisionPlugin",
     "ShardedNativeMixedPrecisionPlugin",
     "TPUPrecisionPlugin",

src/pytorch_lightning/plugins/precision/sharded_native_amp.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
 from typing import Optional, Union

 from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
-from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.native_amp import MixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _FAIRSCALE_AVAILABLE:
@@ -24,7 +24,7 @@
     OSS = ShardedGradScaler = object


-class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
+class ShardedNativeMixedPrecisionPlugin(MixedPrecisionPlugin):
     """Native AMP for Sharded Training."""

     def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
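Since `ShardedNativeMixedPrecisionPlugin` now inherits from `MixedPrecisionPlugin`, its constructor is unchanged. A hedged usage sketch, assuming fairscale is installed; the `ShardedGradScaler` import path mirrors fairscale's public API and is an assumption, not part of this diff:

# Sketch only: requires fairscale.
from fairscale.optim.grad_scaler import ShardedGradScaler
from pytorch_lightning.plugins.precision import ShardedNativeMixedPrecisionPlugin

plugin = ShardedNativeMixedPrecisionPlugin(precision=16, device="cuda", scaler=ShardedGradScaler())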

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 2 additions & 4 deletions
@@ -29,7 +29,7 @@
 from lightning_lite.utilities.cloud_io import get_filesystem
 from lightning_lite.utilities.types import _PATH
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, MixedPrecisionPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -295,9 +295,7 @@ def restore_precision_plugin_state(self) -> None:
         # old checkpoints compatibility
         if "amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, ApexMixedPrecisionPlugin):
             prec_plugin.load_state_dict(self._loaded_checkpoint["amp_scaling_state"])
-        if "native_amp_scaling_state" in self._loaded_checkpoint and isinstance(
-            prec_plugin, NativeMixedPrecisionPlugin
-        ):
+        if "native_amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, MixedPrecisionPlugin):
            prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"])

    def _restore_quantization_callbacks(self) -> None:
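The simplified branch above still restores scaler state from the legacy `native_amp_scaling_state` checkpoint key. A rough, illustrative sketch of the shape it handles; this is not the full checkpoint schema:

# Illustrative only: the legacy key restore_precision_plugin_state() still recognizes.
# The payload is whatever torch.cuda.amp.GradScaler().state_dict() produced at save time.
import torch

legacy_checkpoint = {
    "state_dict": {},  # model weights, omitted here
    "native_amp_scaling_state": torch.cuda.amp.GradScaler(enabled=False).state_dict(),
}
# With a MixedPrecisionPlugin active, the connector feeds this entry into
# prec_plugin.load_state_dict(...), as the hunk above shows.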

tests/tests_lite/plugins/precision/test_deepspeed.py

Lines changed: 3 additions & 19 deletions
@@ -23,27 +23,11 @@

 def test_invalid_precision_with_deepspeed_precision():
     with pytest.raises(ValueError, match="is not supported in DeepSpeed. `precision` must be one of"):
-        DeepSpeedPrecision(precision=64, amp_type="native")
-
-
-def test_deepspeed_precision_apex_not_installed(monkeypatch):
-    import lightning_lite.plugins.precision.deepspeed as deepspeed
-
-    monkeypatch.setattr(deepspeed, "_APEX_AVAILABLE", False)
-    with pytest.raises(ImportError, match="You have asked for Apex AMP but `apex` is not installed."):
-        DeepSpeedPrecision(precision=16, amp_type="apex")
-
-
-@mock.patch("lightning_lite.plugins.precision.deepspeed._APEX_AVAILABLE", return_value=True)
-def test_deepspeed_precision_apex_default_level(_):
-    with pytest.deprecated_call(match="apex AMP implementation has been deprecated"):
-        precision = DeepSpeedPrecision(precision=16, amp_type="apex")
-    assert isinstance(precision, DeepSpeedPrecision)
-    assert precision.amp_level == "O2"
+        DeepSpeedPrecision(precision=64)


 def test_deepspeed_precision_backward():
-    precision = DeepSpeedPrecision(precision=32, amp_type="native")
+    precision = DeepSpeedPrecision(precision=32)
     tensor = Mock()
     model = Mock()
     precision.backward(tensor, model, "positional-arg", keyword="arg")
@@ -61,7 +45,7 @@ def test_deepspeed_engine_is_steppable(engine):


 def test_deepspeed_precision_optimizer_step():
-    precision = DeepSpeedPrecision(precision=32, amp_type="native")
+    precision = DeepSpeedPrecision(precision=32)
     optimizer = model = Mock()
     precision.optimizer_step(optimizer, lr_kwargs=dict())
     model.step.assert_called_once_with(lr_kwargs=dict())
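These test updates reflect that the Lite `DeepSpeedPrecision` plugin now takes only the precision value, with the Apex-specific `amp_type` path removed. A minimal sketch based on the updated tests; the import path is assumed from the module patched in the removed tests:

# Sketch mirroring the tests above; the amp_type argument is gone.
from unittest.mock import Mock
from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision

precision = DeepSpeedPrecision(precision=16)
model = Mock()
precision.backward(Mock(), model)  # under DeepSpeed this delegates to model.backward(tensor)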

tests/tests_pytorch/deprecated_api/test_remove_1-10.py

Lines changed: 1 addition & 0 deletions
@@ -355,6 +355,7 @@ def test_profiler_classes_deprecated_warning(cls):
         cls()


+@RunIf(amp_apex=True)
 def test_apex_deprecation_warnings():
     class MyModel(BoringModel):
         def optimizer_step(
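The new `@RunIf(amp_apex=True)` guard skips this test when Apex is unavailable. Roughly equivalent to a plain pytest skipif, sketched below; the availability check is a stand-in, not Lightning's actual `RunIf` implementation:

# Rough equivalent of the guard added above, using plain pytest.
import importlib.util
import pytest

_APEX_INSTALLED = importlib.util.find_spec("apex") is not None

@pytest.mark.skipif(not _APEX_INSTALLED, reason="apex is not installed")
def test_apex_deprecation_warnings_guarded():
    ...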

tests/tests_pytorch/models/test_ddp_fork_amp.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@

 import torch

-from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
+from pytorch_lightning.plugins import MixedPrecisionPlugin
 from tests_pytorch.helpers.runif import RunIf


@@ -24,7 +24,7 @@
 def test_amp_gpus_ddp_fork():
     """Ensure the use of native AMP with `ddp_fork` (or associated alias strategies) does not generate CUDA
     initialization errors."""
-    _ = NativeMixedPrecisionPlugin(precision=16, device="cuda")
+    _ = MixedPrecisionPlugin(precision=16, device="cuda")
     with multiprocessing.get_context("fork").Pool(1) as pool:
         in_bad_fork = pool.apply(torch.cuda._is_in_bad_fork)
         assert not in_bad_fork

tests/tests_pytorch/models/test_hooks.py

Lines changed: 3 additions & 2 deletions
@@ -518,14 +518,15 @@ def training_step(self, batch, batch_idx):
         "state_dict": ANY,
         "loops": ANY,
     }
-    if kwargs.get("precision") == 16:
+    using_deepspeed = kwargs.get("strategy") == "deepspeed"
+    if kwargs.get("precision") == 16 and not using_deepspeed:
         saved_ckpt[trainer.precision_plugin.__class__.__qualname__] = ANY
     device = torch.device("cuda:0" if "accelerator" in kwargs and kwargs["accelerator"] == "gpu" else "cpu")
     expected = [
         dict(name="configure_callbacks"),
         dict(name="prepare_data"),
         # DeepSpeed needs the batch size to figure out throughput logging
-        *([dict(name="train_dataloader")] if kwargs.get("strategy") == "deepspeed" else []),
+        *([dict(name="train_dataloader")] if using_deepspeed else []),
         dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")),
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="configure_sharded_model"),

tests/tests_pytorch/strategies/test_sharded_strategy.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
+from pytorch_lightning.plugins import MixedPrecisionPlugin
 from pytorch_lightning.strategies import DDPShardedStrategy, DDPSpawnShardedStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from tests_pytorch.helpers.runif import RunIf
@@ -91,7 +91,7 @@ def test_ddp_choice_sharded_amp(strategy, expected):
     """Test to ensure that plugin native amp plugin is correctly chosen when using sharded."""
     trainer = Trainer(fast_dev_run=True, accelerator="gpu", devices=1, precision=16, strategy=strategy)
     assert isinstance(trainer.strategy, expected)
-    assert isinstance(trainer.precision_plugin, NativeMixedPrecisionPlugin)
+    assert isinstance(trainer.precision_plugin, MixedPrecisionPlugin)


 @RunIf(fairscale=True)
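The assertion above corresponds to the user-facing configuration below; a hedged sketch of the equivalent Trainer call, assuming fairscale and a GPU are available:

# Sketch: with precision=16 and a sharded strategy, the Trainer is expected to
# select MixedPrecisionPlugin automatically, as the test above asserts.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import MixedPrecisionPlugin

trainer = Trainer(fast_dev_run=True, accelerator="gpu", devices=1, precision=16, strategy="ddp_sharded")
assert isinstance(trainer.precision_plugin, MixedPrecisionPlugin)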

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 3 additions & 3 deletions
@@ -432,9 +432,9 @@ def test_validate_precision_type(precision):


 def test_amp_level_raises_error_with_native():
-    with pytest.deprecated_call(
-        match="Setting `amp_level` inside the `Trainer` is deprecated in v1.8.0"
-    ), pytest.raises(MisconfigurationException, match="O2'` but it's only supported with `amp_backend='apex'`"):
+    with pytest.deprecated_call(match="apex AMP implementation has been deprecated"), pytest.raises(
+        MisconfigurationException, match="O2'` but it's only supported with `amp_backend='apex'`"
+    ):
         _ = Trainer(amp_level="O2", amp_backend="native", precision=16)
