import warnings
from contextlib import contextmanager
from functools import lru_cache
- from typing import Dict, Generator, List, Optional, Set, Union
+ from typing import cast, Dict, Generator, List, Optional, Union

import torch
from lightning_utilities.core.rank_zero import rank_zero_info

from lightning_fabric.accelerators.accelerator import Accelerator
- from lightning_fabric.utilities.imports import (
-     _TORCH_GREATER_EQUAL_1_12,
-     _TORCH_GREATER_EQUAL_1_13,
-     _TORCH_GREATER_EQUAL_2_0,
- )
+ from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0


class CUDAAccelerator(Accelerator):
@@ -161,11 +157,11 @@ def num_cuda_devices() -> int:
    Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support
    if the platform allows it.
    """
-     if _TORCH_GREATER_EQUAL_1_13:
+     if _TORCH_GREATER_EQUAL_2_0:
        return torch.cuda.device_count()

    # Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879
-     # TODO: Remove once minimum supported PyTorch version is 1.13
+     # TODO: Remove once minimum supported PyTorch version is 2.0
    nvml_count = _device_count_nvml()
    return torch.cuda.device_count() if nvml_count < 0 else nvml_count
@@ -180,63 +176,167 @@ def is_cuda_available() -> bool:
    return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_2_0 else num_cuda_devices() > 0


- # TODO: Remove once minimum supported PyTorch version is 1.13
- def _parse_visible_devices() -> Set[int]:
-     """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
+ # TODO: Remove once minimum supported PyTorch version is 2.0
+ def _parse_visible_devices() -> Union[List[int], List[str]]:
+     """Parse the CUDA_VISIBLE_DEVICES environment variable."""
    var = os.getenv("CUDA_VISIBLE_DEVICES")
    if var is None:
-         return {x for x in range(64)}
+         return list(range(64))

    def _strtoul(s: str) -> int:
-         """Return -1 or the integer the sequence string starts with."""
-         if len(s) == 0:
+         """Return -1 or the positive integer the sequence string starts with."""
+         if not s:
            return -1
        for idx, c in enumerate(s):
-             if not c.isdigit():
+             if not (c.isdigit() or (idx == 0 and c in "+-")):
                break
        if idx + 1 == len(s):
            idx += 1
        return int(s[:idx]) if idx > 0 else -1

+     def parse_list_with_prefix(lst: str, prefix: str) -> List[str]:
+         rcs: List[str] = []
+         for elem in lst.split(","):
+             # A repeated id results in an empty list
+             if elem in rcs:
+                 return cast(List[str], [])
+             # Anything without the prefix stops the parse
+             if not elem.startswith(prefix):
+                 break
+             rcs.append(elem)
+         return rcs
+
+     if var.startswith("GPU-"):
+         return parse_list_with_prefix(var, "GPU-")
+     if var.startswith("MIG-"):
+         return parse_list_with_prefix(var, "MIG-")
    # CUDA_VISIBLE_DEVICES uses something like strtoul,
    # which makes `1gpu2,2ampere` equivalent to `1,2`
-     rc: Set[int] = set()
+     rc: List[int] = []
    for elem in var.split(","):
-         rc.add(_strtoul(elem.strip()))
+         x = _strtoul(elem.strip())
+         # A repeated ordinal results in an empty list
+         if x in rc:
+             return cast(List[int], [])
+         # A negative value aborts the sequence
+         if x < 0:
+             break
+         rc.append(x)
    return rc
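
A quick illustration of the parsing above (editor's example with hypothetical values, not part of the diff):

# _parse_visible_devices() resolves CUDA_VISIBLE_DEVICES as follows:
#   unset                -> [0, 1, ..., 63]
#   "0,2,1"              -> [0, 2, 1]
#   "1gpu2,2ampere"      -> [1, 2]   (strtoul-like: the leading integer of each entry is kept)
#   "0,0"                -> []       (a repeated ordinal empties the list)
#   "0,-1,2"             -> [0]      (a negative value aborts the sequence)
#   "GPU-a1b2,GPU-c3d4"  -> ["GPU-a1b2", "GPU-c3d4"]   (UUID form is returned as strings)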

- # TODO: Remove once minimum supported PyTorch version is 1.13
+ # TODO: Remove once minimum supported PyTorch version is 2.0
def _raw_device_count_nvml() -> int:
-     """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
-     from ctypes import c_int, CDLL
+     """Return the number of devices as reported by NVML, or a negative value if NVML discovery/initialization failed."""
+     from ctypes import byref, c_int, CDLL

    nvml_h = CDLL("libnvidia-ml.so.1")
    rc = nvml_h.nvmlInit()
    if rc != 0:
        warnings.warn("Can't initialize NVML")
        return -1
-     dev_arr = (c_int * 1)(-1)
-     rc = nvml_h.nvmlDeviceGetCount_v2(dev_arr)
+     dev_count = c_int(-1)
+     rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
    if rc != 0:
        warnings.warn("Can't get nvml device count")
        return -1
    del nvml_h
-     return dev_arr[0]
+     return dev_count.value
+
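
Editor's note on the out-parameter change above: NVML's C signature is `nvmlReturn_t nvmlDeviceGetCount_v2(unsigned int *deviceCount)`, so both variants pass a pointer to writable int storage:

#   dev_arr = (c_int * 1)(-1);  nvml_h.nvmlDeviceGetCount_v2(dev_arr)           # ctypes array decays to a pointer
#   dev_count = c_int(-1);      nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))  # explicit byref, read via .value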

+ # TODO: Remove once minimum supported PyTorch version is 2.0
+ def _raw_device_uuid_nvml() -> Optional[List[str]]:
+     """Return a list of device UUIDs as reported by NVML, or None if NVML discovery/initialization failed."""
+     from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer

- # TODO: Remove once minimum supported PyTorch version is 1.13
+     nvml_h = CDLL("libnvidia-ml.so.1")
+     rc = nvml_h.nvmlInit()
+     if rc != 0:
+         warnings.warn("Can't initialize NVML")
+         return None
+     dev_count = c_int(-1)
+     rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
+     if rc != 0:
+         warnings.warn("Can't get nvml device count")
+         return None
+     uuids: List[str] = []
+     for idx in range(dev_count.value):
+         dev_id = c_void_p()
+         rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
+         if rc != 0:
+             warnings.warn("Can't get device handle")
+             return None
+         buf_len = 96
+         buf = create_string_buffer(buf_len)
+         rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
+         if rc != 0:
+             warnings.warn("Can't get device UUID")
+             return None
+         uuids.append(buf.raw.decode("ascii").strip("\0"))
+     del nvml_h
+     return uuids
+
+
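For context (editor's note; the exact format is an assumption from NVML conventions, not taken from this diff), a plain GPU UUID string looks like

#   "GPU-9cf3d4b1-6a2f-4e8d-b1c7-0a5e8f2d3c4b"   # hypothetical value
# i.e. "GPU-" plus a 36-character UUID, so the 96-byte buffer above leaves
# headroom for the longer MIG-style identifiers.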
+ # TODO: Remove once minimum supported PyTorch version is 2.0
+ def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]:
+     """Given a list of partial UUIDs and a list of known UUIDs, build a list of ordinals excluding ambiguous
+     partial IDs."""
+
+     def uuid_to_orinal(candidate: str, uuids: List[str]) -> int:
+         best_match = -1
+         for idx, uuid in enumerate(uuids):
+             if not uuid.startswith(candidate):
+                 continue
+             # Ambiguous candidate
+             if best_match != -1:
+                 return -1
+             best_match = idx
+         return best_match
+
+     rc: List[int] = []
+     for candidate in candidates:
+         idx = uuid_to_orinal(candidate, uuids)
+         # The first invalid ordinal stops parsing
+         if idx < 0:
+             break
+         # Duplicates result in an empty list
+         if idx in rc:
+             return cast(List[int], [])
+         rc.append(idx)
+     return rc
+
+
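A few illustrative cases for the helper above (hypothetical UUIDs, editor's example):

# With uuids = ["GPU-1abc", "GPU-1aef"]:
#   _transform_uuid_to_ordinals(["GPU-1abc"], uuids)             -> [0]     (unique prefix match)
#   _transform_uuid_to_ordinals(["GPU-1aef", "GPU-1abc"], uuids) -> [1, 0]  (candidate order preserved)
#   _transform_uuid_to_ordinals(["GPU-1a"], uuids)               -> []      (ambiguous prefix -> -1 -> stop)
#   _transform_uuid_to_ordinals(["GPU-1abc", "GPU-1ab"], uuids)  -> []      (duplicate ordinal empties the list)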
+ # TODO: Remove once minimum supported PyTorch version is 2.0
def _device_count_nvml() -> int:
-     """Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
+     """Return the number of devices as reported by NVML, taking CUDA_VISIBLE_DEVICES into account.
+
+     A negative value is returned if NVML discovery or initialization has failed.
+     """
+     visible_devices = _parse_visible_devices()
+     if not visible_devices:
+         return 0
    try:
-         raw_cnt = _raw_device_count_nvml()
-         if raw_cnt <= 0:
-             return raw_cnt
-         return len(set(range(raw_cnt)).intersection(_parse_visible_devices()))
+         if type(visible_devices[0]) is str:
+             # Skip MIG parsing
+             if visible_devices[0].startswith("MIG-"):
+                 return -1
+             uuids = _raw_device_uuid_nvml()
+             if uuids is None:
+                 return -1
+             visible_devices = _transform_uuid_to_ordinals(cast(List[str], visible_devices), uuids)
+         else:
+             raw_cnt = _raw_device_count_nvml()
+             if raw_cnt <= 0:
+                 return raw_cnt
+             # Trim the list up to the maximum available device
+             for idx, val in enumerate(visible_devices):
+                 if cast(int, val) >= raw_cnt:
+                     return idx
    except OSError:
        return -1
    except AttributeError:
        return -1
+     return len(visible_devices)
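
Observable behavior of the rewritten fallback (editor's example, assuming two physical GPUs and working NVML):

# _device_count_nvml():
#   CUDA_VISIBLE_DEVICES unset      -> 2   ([0..63] trimmed at the raw device count)
#   CUDA_VISIBLE_DEVICES=""         -> 0   (empty parse result)
#   CUDA_VISIBLE_DEVICES="0,1,5"    -> 2   (trimmed at the first out-of-range ordinal)
#   CUDA_VISIBLE_DEVICES="MIG-..."  -> -1  (MIG parsing skipped; num_cuda_devices() then
#                                           falls back to torch.cuda.device_count())
#   libnvidia-ml.so.1 not loadable  -> -1  (CDLL raises OSError, caught above)
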
def _check_cuda_matmul_precision(device: torch.device) -> None: