Commit d90f624

jerome-habana authored and Borda committed
Enable back inference mode support with hpu & update links (#15918)
* Enable back inference mode support with hpu
* Remove unused
* Update document link and address comment

Signed-off-by: Jerome <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 6aaac8b)
1 parent a528d56 commit d90f624

File tree

5 files changed: +6 additions, -13 deletions


docs/source-pytorch/accelerators/hpu_basic.rst

Lines changed: 0 additions & 1 deletion
@@ -113,4 +113,3 @@ Known limitations
 -----------------
 
 * `Habana dataloader <https://docs.habana.ai/en/latest/PyTorch_User_Guide/PyTorch_User_Guide.html#habana-data-loader>`__ is not supported.
-* :func:`torch.inference_mode` is not supported
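
With this limitation removed, evaluation on HPU can run under torch.inference_mode again. A minimal usage sketch (assuming a Gaudi machine with the Habana PyTorch plugin installed; BoringModel merely stands in for any LightningModule, and inference_mode=True is already the Trainer default):

    from pytorch_lightning import Trainer
    from pytorch_lightning.demos.boring_classes import BoringModel

    model = BoringModel()
    # inference_mode=True makes validate/test/predict run under
    # torch.inference_mode() rather than torch.no_grad(); after this
    # commit the HPU accelerator no longer opts out of it.
    trainer = Trainer(accelerator="hpu", devices=1, inference_mode=True)
    trainer.validate(model)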

docs/source-pytorch/accelerators/hpu_intermediate.rst

Lines changed: 1 addition & 1 deletion
@@ -96,4 +96,4 @@ The below snippet shows how DeviceStatsMonitor can be enabled.
     device_stats = DeviceStatsMonitor()
     trainer = Trainer(accelerator="hpu", callbacks=[device_stats])
 
-For more details, please refer to `Memory Stats APIs <https://docs.habana.ai/en/v1.5.0/PyTorch/PyTorch_User_Guide/Python_Packages.html#memory-stats-apis>`__.
+For more details, please refer to `Memory Stats APIs <https://docs.habana.ai/en/latest/PyTorch/PyTorch_User_Guide/Python_Packages.html#memory-stats-apis>`__.
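
For reference, the documented snippet with its imports spelled out (a sketch, not part of this diff):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import DeviceStatsMonitor

    # DeviceStatsMonitor logs per-device statistics at each step; on HPU
    # these are sourced from the Memory Stats APIs linked above.
    device_stats = DeviceStatsMonitor()
    trainer = Trainer(accelerator="hpu", callbacks=[device_stats])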

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ## [1.8.4] - 2022-12-06
 
 - Direct support for compiled models ([#15922](https://github.com/Lightning-AI/lightning/pull/15922))
+- Fixed issue with unsupported torch.inference_mode() on hpu backends ([#15918](https://github.com/Lightning-AI/lightning/pull/15918))
 
 
 ## [1.8.3] - 2022-11-22

src/pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 2 additions & 7 deletions
@@ -80,12 +80,7 @@
 from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import (
-    _HOROVOD_AVAILABLE,
-    _HPU_AVAILABLE,
-    _IPU_AVAILABLE,
-    _TORCH_GREATER_EQUAL_1_11,
-)
+from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _IPU_AVAILABLE, _TORCH_GREATER_EQUAL_1_11
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 
 log = logging.getLogger(__name__)
@@ -499,7 +494,7 @@ def _choose_auto_accelerator(self) -> str:
             return "tpu"
         if _IPU_AVAILABLE:
             return "ipu"
-        if _HPU_AVAILABLE:
+        if HPUAccelerator.is_available():
             return "hpu"
         if MPSAccelerator.is_available():
             return "mps"

src/pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 4 deletions
@@ -48,7 +48,7 @@
 from lightning_lite.utilities.data import _auto_add_worker_init_fn
 from lightning_lite.utilities.types import _PATH
 from lightning_lite.utilities.warnings import PossibleUserWarning
-from pytorch_lightning.accelerators import Accelerator, HPUAccelerator, TPUAccelerator
+from pytorch_lightning.accelerators import Accelerator, TPUAccelerator
 from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.datamodule import LightningDataModule
@@ -2265,13 +2265,11 @@ def configure_optimizers(self):
 
 @contextmanager
 def _evaluation_context(accelerator: Accelerator, inference_mode: bool = True) -> Generator:
-    # inference mode is not supported with gloo backend (#9431),
-    # and HPU & TPU accelerators.
+    # inference mode is not supported with gloo backend (#9431) and TPU accelerators.
     context_manager_class = (
         torch.inference_mode
         if inference_mode
         and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
-        and not isinstance(accelerator, HPUAccelerator)
         and not isinstance(accelerator, TPUAccelerator)
         else torch.no_grad
     )
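
A standalone sketch of the selection logic above (simplified; evaluation_context and the is_tpu flag are renamed for illustration and are not the library's API): the loop uses torch.inference_mode unless it is disabled, the gloo backend is active, or the accelerator is a TPU, and HPU no longer forces the torch.no_grad fallback.

    from contextlib import contextmanager
    from typing import Generator

    import torch
    import torch.distributed as dist

    @contextmanager
    def evaluation_context(is_tpu: bool, inference_mode: bool = True) -> Generator:
        use_inference_mode = (
            inference_mode
            and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
            and not is_tpu
        )
        # torch.inference_mode additionally disables view tracking and
        # version-counter bumps, so it is preferred where supported.
        ctx = torch.inference_mode if use_inference_mode else torch.no_grad
        with ctx():
            yield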
