Skip to content

Commit bda70a2

Browse files
authored
Integrate lightning_utilities get_all_subclasses (#14575)
1 parent 8c4184c commit bda70a2

File tree

4 files changed

+14
-24
lines changed

4 files changed

+14
-24
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2727
- Integrate the `lightning_utilities` package (
2828
[#14475](https://github.com/Lightning-AI/lightning/issues/14475),
2929
[#14537](https://github.com/Lightning-AI/lightning/issues/14537),
30-
[#14556](https://github.com/Lightning-AI/lightning/issues/14556))
30+
[#14556](https://github.com/Lightning-AI/lightning/issues/14556),
31+
[#14558](https://github.com/Lightning-AI/lightning/issues/14558),
32+
[#14575](https://github.com/Lightning-AI/lightning/issues/14575))
3133

3234

3335
### Changed

src/pytorch_lightning/utilities/cli.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
from typing import Any, Generator, List, Optional, Tuple, Type
1919

2020
import torch
21+
from lightning_utilities.core.inheritance import get_all_subclasses
2122
from torch.optim import Optimizer
2223

2324
import pytorch_lightning as pl
2425
import pytorch_lightning.cli as new_cli
25-
from pytorch_lightning.utilities.meta import _get_all_subclasses
2626
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
2727

2828
_deprecate_registry_message = (
@@ -108,17 +108,17 @@ def _populate_registries(subclasses: bool) -> None: # Remove in v1.9
108108
if subclasses:
109109
rank_zero_deprecation(_deprecate_auto_registry_message)
110110
# this will register any subclasses from all loaded modules including userland
111-
for cls in _get_all_subclasses(torch.optim.Optimizer):
111+
for cls in get_all_subclasses(torch.optim.Optimizer):
112112
OPTIMIZER_REGISTRY(cls, show_deprecation=False)
113-
for cls in _get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
113+
for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
114114
LR_SCHEDULER_REGISTRY(cls, show_deprecation=False)
115-
for cls in _get_all_subclasses(pl.Callback):
115+
for cls in get_all_subclasses(pl.Callback):
116116
CALLBACK_REGISTRY(cls, show_deprecation=False)
117-
for cls in _get_all_subclasses(pl.LightningModule):
117+
for cls in get_all_subclasses(pl.LightningModule):
118118
MODEL_REGISTRY(cls, show_deprecation=False)
119-
for cls in _get_all_subclasses(pl.LightningDataModule):
119+
for cls in get_all_subclasses(pl.LightningDataModule):
120120
DATAMODULE_REGISTRY(cls, show_deprecation=False)
121-
for cls in _get_all_subclasses(pl.loggers.Logger):
121+
for cls in get_all_subclasses(pl.loggers.Logger):
122122
LOGGER_REGISTRY(cls, show_deprecation=False)
123123
else:
124124
# manually register torch's subclasses and our subclasses

src/pytorch_lightning/utilities/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import torch
2424
from lightning_utilities.core.apply_func import is_dataclass_instance
25+
from lightning_utilities.core.inheritance import get_all_subclasses
2526
from lightning_utilities.core.rank_zero import WarningCache
2627
from torch import Tensor
2728
from torch.utils.data import (
@@ -40,7 +41,6 @@
4041
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
4142
from pytorch_lightning.utilities.enums import _FaultTolerantMode, LightningEnum
4243
from pytorch_lightning.utilities.exceptions import MisconfigurationException
43-
from pytorch_lightning.utilities.meta import _get_all_subclasses
4444
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
4545
from pytorch_lightning.utilities.seed import pl_worker_init_function
4646

@@ -549,7 +549,7 @@ def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] =
549549
550550
It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods.
551551
"""
552-
classes = _get_all_subclasses(base_cls) | {base_cls}
552+
classes = get_all_subclasses(base_cls) | {base_cls}
553553
for cls in classes:
554554
# Check that __init__ belongs to the class
555555
# https://stackoverflow.com/a/5253424

src/pytorch_lightning/utilities/meta.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,9 @@ def get_all_subclasses(cls: Type) -> Set[Type]:
4343
"`pytorch_lightning.utilities.meta.get_all_subclasses` is deprecated in v1.8 and will be removed in v1.9."
4444
" Please copy its implementation if you have a use for it."
4545
)
46-
return _get_all_subclasses(cls)
46+
from lightning_utilities.core.inheritance import get_all_subclasses as new_get_all_subclasses
4747

48-
49-
# https://stackoverflow.com/a/63851681/9201239
50-
def _get_all_subclasses(cls: Type) -> Set[Type]:
51-
subclass_list = []
52-
53-
def recurse(cl: Type) -> None:
54-
for subclass in cl.__subclasses__():
55-
subclass_list.append(subclass)
56-
recurse(subclass)
57-
58-
recurse(cls)
59-
60-
return set(subclass_list)
48+
return new_get_all_subclasses(cls)
6149

6250

6351
def recursively_setattr(root_module: Any, prefix: str, materialized_module: Module) -> None:

0 commit comments

Comments
 (0)