@@ -94,7 +94,7 @@ def __init__(
94
94
self ,
95
95
accelerator : Optional [Union [str , Accelerator ]] = None ,
96
96
strategy : Optional [Union [str , Strategy ]] = None ,
97
- devices : Optional [ Union [List [int ], str , int ]] = None ,
97
+ devices : Union [List [int ], str , int ] = "auto" ,
98
98
num_nodes : int = 1 ,
99
99
precision : _PRECISION_INPUT = 32 ,
100
100
plugins : Optional [Union [_PLUGIN_INPUT , List [_PLUGIN_INPUT ]]] = None ,
@@ -103,7 +103,7 @@ def __init__(
103
103
# These arguments can be set through environment variables set by the CLI
104
104
accelerator = self ._argument_from_env ("accelerator" , accelerator , default = None )
105
105
strategy = self ._argument_from_env ("strategy" , strategy , default = None )
106
- devices = self ._argument_from_env ("devices" , devices , default = None )
106
+ devices = self ._argument_from_env ("devices" , devices , default = "auto" )
107
107
num_nodes = self ._argument_from_env ("num_nodes" , num_nodes , default = 1 )
108
108
precision = self ._argument_from_env ("precision" , precision , default = 32 )
109
109
@@ -277,9 +277,7 @@ def _check_config_and_set_final_flags(
277
277
self ._accelerator_flag = "cuda"
278
278
self ._parallel_devices = self ._strategy_flag .parallel_devices
279
279
280
- def _check_device_config_and_set_final_flags (
281
- self , devices : Optional [Union [List [int ], str , int ]], num_nodes : int
282
- ) -> None :
280
+ def _check_device_config_and_set_final_flags (self , devices : Union [List [int ], str , int ], num_nodes : int ) -> None :
283
281
self ._num_nodes_flag = int (num_nodes ) if num_nodes is not None else 1
284
282
self ._devices_flag = devices
285
283
@@ -348,7 +346,7 @@ def _set_parallel_devices_and_init_accelerator(self) -> None:
348
346
self ._parallel_devices = accelerator_cls .get_parallel_devices (self ._devices_flag )
349
347
350
348
def _set_devices_flag_if_auto_passed (self ) -> None :
351
- if self ._devices_flag == "auto" or self . _devices_flag is None :
349
+ if self ._devices_flag == "auto" :
352
350
self ._devices_flag = self .accelerator .auto_device_count ()
353
351
354
352
def _choose_and_init_cluster_environment (self ) -> ClusterEnvironment :
0 commit comments