Skip to content

Commit fd8dbf6

Browse files
committed
Request torch.cuda RNG states only if CUDA is available (#19234)
1 parent f04f9f6 commit fd8dbf6

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

src/lightning/fabric/utilities/seed.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,7 @@ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
115115
"python": python_get_rng_state(),
116116
}
117117
if include_cuda:
118-
states["torch.cuda"] = torch.cuda.get_rng_state_all()
118+
states["torch.cuda"] = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else []
119119
return states
120120

121121

tests/tests_fabric/utilities/test_seed.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import os
22
from unittest import mock
3+
from unittest.mock import Mock
34

45
import lightning.fabric.utilities
56
import pytest
@@ -81,3 +82,12 @@ def test_backward_compatibility_rng_states_dict():
8182
assert "torch.cuda" in states
8283
states.pop("torch.cuda")
8384
_set_rng_states(states)
85+
86+
87+
@mock.patch("lightning.fabric.utilities.seed.torch.cuda.is_available", Mock(return_value=False))
88+
@mock.patch("lightning.fabric.utilities.seed.torch.cuda.get_rng_state_all")
89+
def test_collect_rng_states_if_cuda_init_fails(get_rng_state_all_mock):
90+
"""Test that the `torch.cuda` rng states are only requested if CUDA is available."""
91+
get_rng_state_all_mock.side_effect = RuntimeError("The NVIDIA driver on your system is too old")
92+
states = _collect_rng_states()
93+
assert states["torch.cuda"] == []

tests/tests_pytorch/utilities/test_seed.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,5 +47,5 @@ def test_isolate_rng_cuda(get_cuda_rng, set_cuda_rng):
4747
set_cuda_rng.assert_not_called()
4848

4949
with isolate_rng(include_cuda=True):
50-
get_cuda_rng.assert_called_once()
50+
assert get_cuda_rng.call_count == int(torch.cuda.is_available())
5151
set_cuda_rng.assert_called_once()

0 commit comments

Comments
 (0)