Commit b0c7908

awaelchli authored and lexierule committed
Incorporate PyTorch's fixes in device_count_nvml (#16795)
1 parent 058dd9f commit b0c7908

File tree

4 files changed: +132, -29 lines


src/lightning_fabric/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed edge cases in parsing device ids using NVML ([#16795](https://github.com/Lightning-AI/lightning/pull/16795))
 - Fixed DDP spawn hang on TPU Pods ([#16844](https://github.com/Lightning-AI/lightning/pull/16844))
 

src/lightning_fabric/accelerators/cuda.py

Lines changed: 129 additions & 29 deletions
@@ -15,17 +15,13 @@
 import warnings
 from contextlib import contextmanager
 from functools import lru_cache
-from typing import Dict, Generator, List, Optional, Set, Union
+from typing import cast, Dict, Generator, List, Optional, Union
 
 import torch
 from lightning_utilities.core.rank_zero import rank_zero_info
 
 from lightning_fabric.accelerators.accelerator import Accelerator
-from lightning_fabric.utilities.imports import (
-    _TORCH_GREATER_EQUAL_1_12,
-    _TORCH_GREATER_EQUAL_1_13,
-    _TORCH_GREATER_EQUAL_2_0,
-)
+from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
 
 
 class CUDAAccelerator(Accelerator):
@@ -161,11 +157,11 @@ def num_cuda_devices() -> int:
     Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support,
     if the platform allows it.
     """
-    if _TORCH_GREATER_EQUAL_1_13:
+    if _TORCH_GREATER_EQUAL_2_0:
         return torch.cuda.device_count()
 
     # Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879
-    # TODO: Remove once minimum supported PyTorch version is 1.13
+    # TODO: Remove once minimum supported PyTorch version is 2.0
     nvml_count = _device_count_nvml()
     return torch.cuda.device_count() if nvml_count < 0 else nvml_count
 
@@ -180,63 +176,167 @@ def is_cuda_available() -> bool:
     return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_2_0 else num_cuda_devices() > 0
 
 
-# TODO: Remove once minimum supported PyTorch version is 1.13
-def _parse_visible_devices() -> Set[int]:
-    """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
+# TODO: Remove once minimum supported PyTorch version is 2.0
+def _parse_visible_devices() -> Union[List[int], List[str]]:
+    """Parse the CUDA_VISIBLE_DEVICES environment variable."""
     var = os.getenv("CUDA_VISIBLE_DEVICES")
     if var is None:
-        return {x for x in range(64)}
+        return list(range(64))
 
     def _strtoul(s: str) -> int:
-        """Return -1 or integer sequence string starts with."""
-        if len(s) == 0:
+        """Return -1 or the positive integer the string starts with."""
+        if not s:
             return -1
         for idx, c in enumerate(s):
-            if not c.isdigit():
+            if not (c.isdigit() or (idx == 0 and c in "+-")):
                 break
             if idx + 1 == len(s):
                 idx += 1
         return int(s[:idx]) if idx > 0 else -1
 
+    def parse_list_with_prefix(lst: str, prefix: str) -> List[str]:
+        rcs: List[str] = []
+        for elem in lst.split(","):
+            # A repeated id results in an empty list
+            if elem in rcs:
+                return cast(List[str], [])
+            # Anything other than the prefix is ignored
+            if not elem.startswith(prefix):
+                break
+            rcs.append(elem)
+        return rcs
+
+    if var.startswith("GPU-"):
+        return parse_list_with_prefix(var, "GPU-")
+    if var.startswith("MIG-"):
+        return parse_list_with_prefix(var, "MIG-")
     # CUDA_VISIBLE_DEVICES uses something like strtoul
-    # which makes `1gpu2,2ampere` is equivalent to `1,2`
-    rc: Set[int] = set()
+    # which makes `1gpu2,2ampere` equivalent to `1,2`
+    rc: List[int] = []
     for elem in var.split(","):
-        rc.add(_strtoul(elem.strip()))
+        x = _strtoul(elem.strip())
+        # A repeated ordinal results in an empty list
+        if x in rc:
+            return cast(List[int], [])
+        # A negative value aborts the sequence
+        if x < 0:
+            break
+        rc.append(x)
     return rc
 
 
-# TODO: Remove once minimum supported PyTorch version is 1.13
+# TODO: Remove once minimum supported PyTorch version is 2.0
 def _raw_device_count_nvml() -> int:
-    """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
-    from ctypes import c_int, CDLL
+    """Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed."""
+    from ctypes import byref, c_int, CDLL
 
     nvml_h = CDLL("libnvidia-ml.so.1")
     rc = nvml_h.nvmlInit()
     if rc != 0:
         warnings.warn("Can't initialize NVML")
         return -1
-    dev_arr = (c_int * 1)(-1)
-    rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr)
+    dev_count = c_int(-1)
+    rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
     if rc != 0:
         warnings.warn("Can't get nvml device count")
         return -1
     del nvml_h
-    return dev_arr[0]
+    return dev_count.value
+
 
+# TODO: Remove once minimum supported PyTorch version is 2.0
+def _raw_device_uuid_nvml() -> Optional[List[str]]:
+    """Return list of device UUIDs as reported by NVML or None if NVML discovery/initialization failed."""
+    from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer
 
-# TODO: Remove once minimum supported PyTorch version is 1.13
+    nvml_h = CDLL("libnvidia-ml.so.1")
+    rc = nvml_h.nvmlInit()
+    if rc != 0:
+        warnings.warn("Can't initialize NVML")
+        return None
+    dev_count = c_int(-1)
+    rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
+    if rc != 0:
+        warnings.warn("Can't get nvml device count")
+        return None
+    uuids: List[str] = []
+    for idx in range(dev_count.value):
+        dev_id = c_void_p()
+        rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
+        if rc != 0:
+            warnings.warn("Can't get device handle")
+            return None
+        buf_len = 96
+        buf = create_string_buffer(buf_len)
+        rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
+        if rc != 0:
+            warnings.warn("Can't get device UUID")
+            return None
+        uuids.append(buf.raw.decode("ascii").strip("\0"))
+    del nvml_h
+    return uuids
+
+
+# TODO: Remove once minimum supported PyTorch version is 2.0
+def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]:
+    """Given a set of partial UUIDs and a list of known UUIDs, build a list of ordinals excluding ambiguous partial
+    IDs."""
+
+    def uuid_to_ordinal(candidate: str, uuids: List[str]) -> int:
+        best_match = -1
+        for idx, uuid in enumerate(uuids):
+            if not uuid.startswith(candidate):
+                continue
+            # Ambiguous candidate
+            if best_match != -1:
+                return -1
+            best_match = idx
+        return best_match
+
+    rc: List[int] = []
+    for candidate in candidates:
+        idx = uuid_to_ordinal(candidate, uuids)
+        # First invalid ordinal stops parsing
+        if idx < 0:
+            break
+        # Duplicates result in an empty list
+        if idx in rc:
+            return cast(List[int], [])
+        rc.append(idx)
    return rc
+
+
+# TODO: Remove once minimum supported PyTorch version is 2.0
 def _device_count_nvml() -> int:
-    """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
+    """Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account.
+
+    Negative value is returned if NVML discovery or initialization has failed.
+    """
+    visible_devices = _parse_visible_devices()
+    if not visible_devices:
+        return 0
     try:
-        raw_cnt = _raw_device_count_nvml()
-        if raw_cnt <= 0:
-            return raw_cnt
-        return len(set(range(raw_cnt)).intersection(_parse_visible_devices()))
+        if type(visible_devices[0]) is str:
+            # Skip MIG parsing
+            if visible_devices[0].startswith("MIG-"):
+                return -1
+            uuids = _raw_device_uuid_nvml()
+            if uuids is None:
+                return -1
+            visible_devices = _transform_uuid_to_ordinals(cast(List[str], visible_devices), uuids)
+        else:
+            raw_cnt = _raw_device_count_nvml()
+            if raw_cnt <= 0:
+                return raw_cnt
+            # Trim the list up to the maximum available device
+            for idx, val in enumerate(visible_devices):
+                if cast(int, val) >= raw_cnt:
+                    return idx
     except OSError:
         return -1
     except AttributeError:
         return -1
+    return len(visible_devices)
 
 
 def _check_cuda_matmul_precision(device: torch.device) -> None:
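
For orientation, here is a minimal sketch (not part of the commit) of how the parsing helpers above behave. The import path follows the diff, and the GPU UUID strings are made-up placeholders:

# Illustrative only: exercises the helpers introduced in the diff above.
import os
from unittest import mock

from lightning_fabric.accelerators.cuda import _parse_visible_devices, _transform_uuid_to_ordinals

cases = {
    "0,1,2": [0, 1, 2],  # plain ordinals
    "1gpu2,2ampere": [1, 2],  # strtoul-like parsing drops trailing junk
    "0,1,0": [],  # a repeated ordinal yields an empty list
    "0,-1,2": [0],  # a negative value aborts the sequence
    "GPU-1a2b,GPU-3c4d": ["GPU-1a2b", "GPU-3c4d"],  # made-up UUID-style ids
}
for value, expected in cases.items():
    with mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": value}):
        assert _parse_visible_devices() == expected

# Partial UUIDs resolve to ordinals by unique prefix; an ambiguous prefix stops parsing.
uuids = ["GPU-1a2b-made-up", "GPU-3c4d-made-up"]
assert _transform_uuid_to_ordinals(["GPU-3c4d", "GPU-1a2b"], uuids) == [1, 0]
assert _transform_uuid_to_ordinals(["GPU-"], uuids) == []

An empty parse result makes _device_count_nvml return 0 straight away, and UUID-style entries are resolved against _raw_device_uuid_nvml before counting, as the diff shows.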

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
 
 - Fixed DDP spawn hang on TPU Pods ([#16844](https://github.com/Lightning-AI/lightning/pull/16844))
+- Fixed edge cases in parsing device ids using NVML ([#16795](https://github.com/Lightning-AI/lightning/pull/16795))
 
 
 ## [1.9.3] - 2023-02-21

tests/tests_fabric/plugins/collectives/test_torch_collective.py

Lines changed: 1 addition & 0 deletions
@@ -233,6 +233,7 @@ def _test_distributed_collectives_fn(strategy, collective):
 
 @skip_distributed_unavailable
 @pytest.mark.parametrize("n", (1, 2))
+@RunIf(skip_windows=True)
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)  # sets CUDA_MODULE_LOADING in torch==1.13
 def test_collectives_distributed(n):
     collective_launch(_test_distributed_collectives_fn, [torch.device("cpu")] * n)
