Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Calling a method other than `forward` that invokes submodules is now an error when the model is wrapped (e.g., with DDP) ([#18819](https://github.com/Lightning-AI/lightning/pull/18819))


- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))


### Deprecated

Expand Down
18 changes: 7 additions & 11 deletions src/lightning/fabric/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@


def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
r"""Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition,
sets the following environment variables:
r"""Function that sets the seed for pseudo-random number generators in: torch, numpy, and Python's random module.
In addition, sets the following environment variables:

- ``PL_GLOBAL_SEED``: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
- ``PL_SEED_WORKERS``: (optional) is set to 1 if ``workers=True``.

Args:
seed: the integer value seed for global random state in Lightning.
If ``None``, will read seed from ``PL_GLOBAL_SEED`` env variable
or select it randomly.
If ``None``, it will read the seed from ``PL_GLOBAL_SEED`` env variable. If ``None`` and the
``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0.
workers: if set to ``True``, will properly configure all dataloaders passed to the
Trainer with a ``worker_init_fn``. If the user already provides such a function
for their dataloaders, setting this argument will have no influence. See also:
Expand All @@ -36,20 +36,20 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
if seed is None:
env_seed = os.environ.get("PL_GLOBAL_SEED")
if env_seed is None:
seed = _select_seed_randomly(min_seed_value, max_seed_value)
seed = 0
rank_zero_warn(f"No seed found, seed set to {seed}")
else:
try:
seed = int(env_seed)
except ValueError:
seed = _select_seed_randomly(min_seed_value, max_seed_value)
seed = 0
rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}")
elif not isinstance(seed, int):
seed = int(seed)

if not (min_seed_value <= seed <= max_seed_value):
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
seed = _select_seed_randomly(min_seed_value, max_seed_value)
seed = 0

log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))
os.environ["PL_GLOBAL_SEED"] = str(seed)
Expand All @@ -63,10 +63,6 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
return seed


def _select_seed_randomly(min_seed_value: int = min_seed_value, max_seed_value: int = max_seed_value) -> int:
return random.randint(min_seed_value, max_seed_value) # noqa: S311


def reset_seed() -> None:
r"""Reset the seed to the value that :func:`~lightning.fabric.utilities.seed.seed_everything` previously set.

Expand Down
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
- `LightningCLI` no longer allows setting a normal class instance as default. A `lazy_instance` can be used instead ([#18822](https://github.com/Lightning-AI/lightning/pull/18822))


Expand Down
18 changes: 11 additions & 7 deletions tests/tests_fabric/utilities/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@
import lightning.fabric.utilities
import pytest
import torch
from lightning.fabric.utilities import seed as seed_utils
from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states


@mock.patch.dict(os.environ, clear=True)
def test_default_seed():
"""Test that the default seed is 0 when no seed provided and no environment variable set."""
assert lightning.fabric.utilities.seed.seed_everything() == 0
assert os.environ["PL_GLOBAL_SEED"] == "0"


@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_stays_same_with_multiple_seed_everything_calls():
"""Ensure that after the initial seed everything, the seed stays the same for the same run."""
Expand All @@ -30,22 +36,20 @@ def test_correct_seed_with_environment_variable():


@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
@mock.patch.object(seed_utils, attribute="_select_seed_randomly", return_value=123)
def test_invalid_seed(_):
def test_invalid_seed():
"""Ensure that we still fix the seed even if an invalid seed is given."""
with pytest.warns(UserWarning, match="Invalid seed found"):
seed = lightning.fabric.utilities.seed.seed_everything()
assert seed == 123
assert seed == 0


@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch.object(seed_utils, attribute="_select_seed_randomly", return_value=123)
@pytest.mark.parametrize("seed", [10e9, -10e9])
def test_out_of_bounds_seed(_, seed):
def test_out_of_bounds_seed(seed):
"""Ensure that we still fix the seed even if an out-of-bounds seed is given."""
with pytest.warns(UserWarning, match="is not in bounds"):
actual = lightning.fabric.utilities.seed.seed_everything(seed)
assert actual == 123
assert actual == 0


def test_reset_seed_no_op():
Expand Down