Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/lightning/fabric/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose:
Args:
seed: the integer value seed for global random state in Lightning.
If ``None``, it will read the seed from ``PL_GLOBAL_SEED`` env variable. If ``None`` and the
``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0.
``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0. If the seed is provided
but is not in bounds or cannot be cast to int, a ValueError is raised.
workers: if set to ``True``, will properly configure all dataloaders passed to the
Trainer with a ``worker_init_fn``. If the user already provides such a function
for their dataloaders, setting this argument will have no influence. See also:
Expand All @@ -44,14 +45,12 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose:
try:
seed = int(env_seed)
except ValueError:
seed = 0
rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}")
raise ValueError(f"Invalid seed specified via PL_GLOBAL_SEED: {repr(env_seed)}")
elif not isinstance(seed, int):
seed = int(seed)

if not (min_seed_value <= seed <= max_seed_value):
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
seed = 0
raise ValueError(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")

if verbose:
log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank()))
Expand Down
14 changes: 6 additions & 8 deletions tests/tests_fabric/utilities/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,17 @@ def test_correct_seed_with_environment_variable():

@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
def test_invalid_seed():
"""Ensure that we still fix the seed even if an invalid seed is given."""
with pytest.warns(UserWarning, match="Invalid seed found"):
seed = seed_everything()
assert seed == 0
"""Ensure that a ValueError is raised if an invalid seed is given."""
with pytest.raises(ValueError, match="Invalid seed specified"):
seed_everything()


@mock.patch.dict(os.environ, {}, clear=True)
@pytest.mark.parametrize("seed", [10e9, -10e9])
def test_out_of_bounds_seed(seed):
"""Ensure that we still fix the seed even if an out-of-bounds seed is given."""
with pytest.warns(UserWarning, match="is not in bounds"):
actual = seed_everything(seed)
assert actual == 0
"""Ensure that a ValueError is raised if an out-of-bounds seed is given."""
with pytest.raises(ValueError, match="is not in bounds"):
seed_everything(seed)


def test_reset_seed_no_op():
Expand Down
Loading