3 changes: 2 additions & 1 deletion src/lightning_lite/lite.py
@@ -22,11 +22,12 @@
import torch.nn as nn
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.overrides import is_overridden
-from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

+from lightning_lite.utilities.rank_zero import rank_zero_warn
+
from lightning_lite.plugins import Precision  # avoid circular imports: # isort: split
from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
3 changes: 1 addition & 2 deletions src/lightning_lite/plugins/environments/slurm.py
@@ -20,10 +20,9 @@
import sys
from typing import Optional

-from lightning_utilities.core.rank_zero import rank_zero_warn
-
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.utilities.imports import _IS_WINDOWS
+from lightning_lite.utilities.rank_zero import rank_zero_warn
from lightning_lite.utilities.warnings import PossibleUserWarning

log = logging.getLogger(__name__)
3 changes: 1 addition & 2 deletions src/lightning_lite/strategies/deepspeed.py
@@ -22,7 +22,6 @@

import torch
from lightning_utilities.core.imports import RequirementCache
-from lightning_utilities.core.rank_zero import rank_zero_only
from torch.nn import Module
from torch.optim import Optimizer

@@ -33,7 +32,7 @@
from lightning_lite.strategies.strategy import _Sharded
from lightning_lite.utilities.distributed import log
from lightning_lite.utilities.enums import PrecisionType
-from lightning_lite.utilities.rank_zero import rank_zero_info
+from lightning_lite.utilities.rank_zero import rank_zero_info, rank_zero_only
from lightning_lite.utilities.seed import reset_seed
from lightning_lite.utilities.types import _PATH

2 changes: 2 additions & 0 deletions src/lightning_lite/utilities/rank_zero.py
@@ -21,11 +21,13 @@

# note: we want to keep these indirections so the `rank_zero_only.rank` is set on import
from lightning_utilities.core.rank_zero import (  # noqa: F401
+    rank_prefixed_message,
    rank_zero_debug,
    rank_zero_deprecation,
    rank_zero_info,
    rank_zero_only,
    rank_zero_warn,
+    WarningCache,
)

import lightning_lite
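The comment in this hunk is the rationale for every import move in the PR: `rank_zero_only` can only suppress non-zero ranks if its `rank` attribute has been set, and `lightning_lite.utilities.rank_zero` sets it as a side effect of being imported, which importing straight from `lightning_utilities` does not guarantee. A minimal sketch of the indirection pattern; the `_get_rank` helper below is a simplified stand-in, not the library's exact environment probing:

# indirection_sketch.py -- simplified model of lightning_lite/utilities/rank_zero.py
import os
from typing import Optional

# Re-export so downstream code imports from this module instead.
from lightning_utilities.core.rank_zero import rank_zero_only  # noqa: F401


def _get_rank() -> Optional[int]:
    # Simplified: the real helper also consults SLURM/torchelastic variables.
    for key in ("RANK", "SLURM_PROCID", "LOCAL_RANK"):
        if key in os.environ:
            return int(os.environ[key])
    return None


# The import-time side effect the note above refers to: any caller importing
# `rank_zero_only` from here gets a decorator whose rank is already resolved.
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank() or 0)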
3 changes: 1 addition & 2 deletions src/lightning_lite/utilities/seed.py
@@ -7,9 +7,8 @@

import numpy as np
import torch
-from lightning_utilities.core.rank_zero import rank_prefixed_message

-from lightning_lite.utilities.rank_zero import _get_rank, rank_zero_only, rank_zero_warn
+from lightning_lite.utilities.rank_zero import _get_rank, rank_prefixed_message, rank_zero_only, rank_zero_warn

log = logging.getLogger(__name__)

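For orientation, `rank_prefixed_message` simply prepends the resolved process rank to a log message. A hedged usage sketch in the style of `seed_everything` logging; only the two imported helpers are the library's, the surrounding code is illustrative:

import logging

from lightning_lite.utilities.rank_zero import _get_rank, rank_prefixed_message

log = logging.getLogger(__name__)

seed = 42
# On rank 1 this logs "[rank: 1] Global seed set to 42"; when no rank can be
# resolved it logs the bare message.
log.info(rank_prefixed_message(f"Global seed set to {seed}", _get_rank()))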
2 changes: 1 addition & 1 deletion src/pytorch_lightning/accelerators/accelerator.py
@@ -15,11 +15,11 @@
from typing import Any, Dict

import torch
-from lightning_utilities.core.rank_zero import rank_zero_deprecation

import pytorch_lightning as pl
from lightning_lite.accelerators.accelerator import Accelerator as _Accelerator
from lightning_lite.utilities.types import _DEVICE
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class Accelerator(_Accelerator, ABC):
3 changes: 1 addition & 2 deletions src/pytorch_lightning/callbacks/early_stopping.py
@@ -23,14 +23,13 @@

import numpy as np
import torch
-from lightning_utilities.core.rank_zero import rank_prefixed_message
from torch import Tensor

import pytorch_lightning as pl
from lightning_lite.utilities.rank_zero import _get_rank
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_prefixed_message, rank_zero_warn

log = logging.getLogger(__name__)

3 changes: 1 addition & 2 deletions src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -31,15 +31,14 @@
import numpy as np
import torch
import yaml
-from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.types import STEP_OUTPUT

log = logging.getLogger(__name__)
3 changes: 1 addition & 2 deletions src/pytorch_lightning/core/module.py
@@ -25,7 +25,6 @@

import torch
from lightning_utilities.core.apply_func import apply_to_collection
-from lightning_utilities.core.rank_zero import WarningCache
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
@@ -49,7 +48,7 @@
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_13
-from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import (
    _METRIC_COLLECTION,
3 changes: 1 addition & 2 deletions src/pytorch_lightning/lite/lite.py
@@ -15,8 +15,6 @@
from abc import ABC
from typing import List, Optional, Tuple, Union

-from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn
-
from lightning_lite.connector import _PLUGIN_INPUT as _LITE_PLUGIN_INPUT
from lightning_lite.connector import _PRECISION_INPUT
from lightning_lite.lite import LightningLite as _NewLightningLite
@@ -52,6 +50,7 @@
from pytorch_lightning.strategies import SingleTPUStrategy as PLSingleTPUStrategy
from pytorch_lightning.strategies import Strategy as PLStrategy
from pytorch_lightning.strategies import TPUSpawnStrategy as PLTPUSpawnStrategy
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn

_PL_PLUGIN = Union[PLPrecisionPlugin, ClusterEnvironment, CheckpointIO]
_PL_PLUGIN_INPUT = Union[_PL_PLUGIN, str]
2 changes: 1 addition & 1 deletion
@@ -2,12 +2,12 @@
from typing import Any, Dict, Iterator, List, Tuple

import torch
-from lightning_utilities.core.rank_zero import WarningCache

from lightning_lite.utilities import move_data_to_device
from pytorch_lightning.loops.loop import Loop
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.progress import Progress
+from pytorch_lightning.utilities.rank_zero import WarningCache

warning_cache = WarningCache()

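Many of the touched modules keep a module-level `warning_cache` like the one above. `WarningCache` deduplicates messages so that a warning inside a hot loop fires once per process. A minimal usage sketch, assuming the `warn` method that `lightning_utilities` provides and the new import path re-exports; the message text is illustrative:

from pytorch_lightning.utilities.rank_zero import WarningCache

warning_cache = WarningCache()  # tracks messages that were already emitted

for batch_idx in range(100):
    # Warns once; later iterations find the message cached and stay silent.
    warning_cache.warn("predict returned None for some batches, skipping them")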
3 changes: 1 addition & 2 deletions src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -18,7 +18,6 @@
import numpy as np
import torch
from lightning_utilities.core.apply_func import apply_to_collection
-from lightning_utilities.core.rank_zero import WarningCache

import pytorch_lightning as pl
from pytorch_lightning import loops  # import as loops to avoid circular imports
@@ -32,7 +31,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
3 changes: 1 addition & 2 deletions
@@ -16,7 +16,6 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
-from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.optim import Optimizer
from typing_extensions import OrderedDict
@@ -34,7 +33,7 @@
from pytorch_lightning.plugins.precision.native_amp import MixedPrecisionPlugin
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, WarningCache
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT

2 changes: 1 addition & 1 deletion src/pytorch_lightning/plugins/precision/colossalai.py
@@ -13,14 +13,14 @@
# limitations under the License.
from typing import Any, Callable, Optional, Union

-from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.types import Steppable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.enums import PrecisionType
+from pytorch_lightning.utilities.rank_zero import WarningCache

warning_cache = WarningCache()

3 changes: 1 addition & 2 deletions src/pytorch_lightning/plugins/precision/deepspeed.py
@@ -14,7 +14,6 @@
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

from lightning_utilities.core.imports import RequirementCache
-from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.optim import LBFGS, Optimizer

@@ -26,7 +25,7 @@
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, WarningCache

_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
2 changes: 1 addition & 1 deletion src/pytorch_lightning/plugins/precision/ipu.py
@@ -13,7 +13,6 @@
# limitations under the License.
from typing import Any, Callable, Union

-from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.optim import LBFGS, Optimizer

@@ -24,6 +23,7 @@
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.rank_zero import WarningCache

warning_cache = WarningCache()

3 changes: 1 addition & 2 deletions src/pytorch_lightning/profilers/pytorch.py
@@ -20,15 +20,14 @@
from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, TYPE_CHECKING, Union

import torch
-from lightning_utilities.core.rank_zero import WarningCache
from torch import nn, Tensor
from torch.autograd.profiler import record_function

from lightning_lite.accelerators.cuda import is_cuda_available
from pytorch_lightning.profilers.profiler import Profiler
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
-from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache

if TYPE_CHECKING:
    from torch.autograd.profiler import EventList
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/colossalai.py
@@ -16,7 +16,6 @@

import torch
from lightning_utilities.core.imports import RequirementCache
-from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
@@ -35,6 +34,7 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import STEP_OUTPUT

_COLOSSALAI_AVAILABLE = RequirementCache("colossalai")
3 changes: 1 addition & 2 deletions src/pytorch_lightning/strategies/deepspeed.py
@@ -24,7 +24,6 @@
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.imports import RequirementCache
-from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
@@ -45,7 +44,7 @@
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.types import LRSchedulerConfig, STEP_OUTPUT

log = logging.getLogger(__name__)
3 changes: 1 addition & 2 deletions
@@ -18,7 +18,6 @@
from weakref import proxy

from lightning_utilities.core.apply_func import apply_to_collection
-from lightning_utilities.core.rank_zero import WarningCache
from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

@@ -35,7 +34,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import PossibleUserWarning

3 changes: 1 addition & 2 deletions
@@ -17,7 +17,6 @@

import torch
from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
-from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torchmetrics import Metric
from typing_extensions import TypedDict
@@ -30,7 +29,7 @@
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.metrics import metrics_to_scalars
-from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
from pytorch_lightning.utilities.warnings import PossibleUserWarning

_IN_METRIC = Union[Metric, Tensor]  # Do not include scalars as they were converted to tensors
3 changes: 1 addition & 2 deletions src/pytorch_lightning/trainer/states.py
@@ -15,10 +15,9 @@
from enum import Enum, EnumMeta
from typing import Any, List, Optional

-from lightning_utilities.core.rank_zero import rank_zero_deprecation
-
from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.enums import _FaultTolerantMode
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class _DeprecationManagingEnumMeta(EnumMeta):
3 changes: 1 addition & 2 deletions src/pytorch_lightning/utilities/data.py
@@ -17,7 +17,6 @@

import torch
from lightning_utilities.core.apply_func import is_dataclass_instance
-from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.utils.data import (
    BatchSampler,
@@ -39,7 +38,7 @@
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn, WarningCache

# might be supported in later releases, see https://github.com/python/mypy/pull/13297
BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]  # type: ignore[misc]
3 changes: 1 addition & 2 deletions src/pytorch_lightning/utilities/migration/migration.py
@@ -31,11 +31,10 @@
import re
from typing import Any, Callable, Dict, List

-from lightning_utilities.core.rank_zero import rank_zero_warn
-
from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn

_CHECKPOINT = Dict[str, Any]

2 changes: 1 addition & 1 deletion src/pytorch_lightning/utilities/migration/utils.py
@@ -17,14 +17,14 @@
from types import ModuleType, TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type

-from lightning_utilities.core.rank_zero import rank_zero_warn
from packaging.version import Version

import pytorch_lightning as pl
from lightning_lite.utilities.imports import _IS_WINDOWS
from lightning_lite.utilities.types import _PATH
from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.utilities.migration.migration import _migration_index
+from pytorch_lightning.utilities.rank_zero import rank_zero_warn

_log = logging.getLogger(__name__)
_CHECKPOINT = Dict[str, Any]
2 changes: 1 addition & 1 deletion
@@ -20,11 +20,11 @@
import numpy as np
import torch
import torch.nn as nn
-from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.utils.hooks import RemovableHandle

import pytorch_lightning as pl
+from pytorch_lightning.utilities.rank_zero import WarningCache

log = logging.getLogger(__name__)
warning_cache = WarningCache()