Skip to content

Commit f3d4f9f

Browse files
committed
Set the explicit default for devices Trainer argument
1 parent a125638 commit f3d4f9f

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

src/lightning_fabric/connector.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
self,
9595
accelerator: Optional[Union[str, Accelerator]] = None,
9696
strategy: Optional[Union[str, Strategy]] = None,
97-
devices: Optional[Union[List[int], str, int]] = None,
97+
devices: Union[List[int], str, int] = "auto",
9898
num_nodes: int = 1,
9999
precision: _PRECISION_INPUT = 32,
100100
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
@@ -103,7 +103,7 @@ def __init__(
103103
# These arguments can be set through environment variables set by the CLI
104104
accelerator = self._argument_from_env("accelerator", accelerator, default=None)
105105
strategy = self._argument_from_env("strategy", strategy, default=None)
106-
devices = self._argument_from_env("devices", devices, default=None)
106+
devices = self._argument_from_env("devices", devices, default="auto")
107107
num_nodes = self._argument_from_env("num_nodes", num_nodes, default=1)
108108
precision = self._argument_from_env("precision", precision, default=32)
109109

@@ -277,9 +277,7 @@ def _check_config_and_set_final_flags(
277277
self._accelerator_flag = "cuda"
278278
self._parallel_devices = self._strategy_flag.parallel_devices
279279

280-
def _check_device_config_and_set_final_flags(
281-
self, devices: Optional[Union[List[int], str, int]], num_nodes: int
282-
) -> None:
280+
def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None:
283281
self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1
284282
self._devices_flag = devices
285283

@@ -348,7 +346,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
348346
self._parallel_devices = accelerator_cls.get_parallel_devices(self._devices_flag)
349347

350348
def _set_devices_flag_if_auto_passed(self) -> None:
351-
if self._devices_flag == "auto" or self._devices_flag is None:
349+
if self._devices_flag == "auto":
352350
self._devices_flag = self.accelerator.auto_device_count()
353351

354352
def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:

src/lightning_fabric/fabric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __init__(
7979
self,
8080
accelerator: Optional[Union[str, Accelerator]] = None,
8181
strategy: Optional[Union[str, Strategy]] = None,
82-
devices: Optional[Union[List[int], str, int]] = None,
82+
devices: Union[List[int], str, int] = "auto",
8383
num_nodes: int = 1,
8484
precision: _PRECISION_INPUT = 32,
8585
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,

src/pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
class AcceleratorConnector:
9191
def __init__(
9292
self,
93-
devices: Optional[Union[List[int], str, int]] = None,
93+
devices: Union[List[int], str, int] = "auto",
9494
num_nodes: int = 1,
9595
accelerator: Optional[Union[str, Accelerator]] = None,
9696
strategy: Optional[Union[str, Strategy]] = None,
@@ -364,7 +364,7 @@ def _check_config_and_set_final_flags(
364364

365365
def _check_device_config_and_set_final_flags(
366366
self,
367-
devices: Optional[Union[List[int], str, int]],
367+
devices: Union[List[int], str, int],
368368
num_nodes: int,
369369
) -> None:
370370
self._num_nodes_flag = int(num_nodes) if num_nodes is not None else 1
@@ -444,7 +444,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
444444
self._parallel_devices = accelerator_cls.get_parallel_devices(self._devices_flag)
445445

446446
def _set_devices_flag_if_auto_passed(self) -> None:
447-
if self._devices_flag == "auto" or self._devices_flag is None:
447+
if self._devices_flag == "auto":
448448
self._devices_flag = self.accelerator.auto_device_count()
449449

450450
def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:

0 commit comments

Comments
 (0)