11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- import os
15
- from unittest import mock
16
-
17
14
import pytest
18
15
import torch
19
16
20
17
from lightning_lite .strategies import DDPStrategy
21
18
22
19
23
20
@pytest .mark .parametrize (
24
- ["process_group_backend" , "env_var" , " device_str" , "expected_process_group_backend" ],
21
+ ["process_group_backend" , "device_str" , "expected_process_group_backend" ],
25
22
[
26
- pytest .param ("foo" , None , "cpu" , "foo" ),
27
- pytest .param ("foo" , "BAR" , "cpu" , "foo" ),
28
- pytest .param ("foo" , "BAR" , "cuda:0" , "foo" ),
29
- pytest .param (None , "BAR" , "cuda:0" , "BAR" ),
30
- pytest .param (None , None , "cuda:0" , "nccl" ),
31
- pytest .param (None , None , "cpu" , "gloo" ),
23
+ pytest .param ("foo" , "cpu" , "foo" ),
24
+ pytest .param ("foo" , "cuda:0" , "foo" ),
25
+ pytest .param (None , "cuda:0" , "nccl" ),
26
+ pytest .param (None , "cpu" , "gloo" ),
32
27
],
33
28
)
34
- def test_ddp_process_group_backend (process_group_backend , env_var , device_str , expected_process_group_backend ):
29
+ def test_ddp_process_group_backend (process_group_backend , device_str , expected_process_group_backend ):
35
30
"""Test settings for process group backend."""
36
31
37
32
class MockDDPStrategy (DDPStrategy ):
@@ -44,11 +39,4 @@ def root_device(self):
44
39
return self ._root_device
45
40
46
41
strategy = MockDDPStrategy (process_group_backend = process_group_backend , root_device = torch .device (device_str ))
47
- if not process_group_backend and env_var :
48
- with mock .patch .dict (os .environ , {"PL_TORCH_DISTRIBUTED_BACKEND" : env_var }):
49
- with pytest .deprecated_call (
50
- match = "Environment variable `PL_TORCH_DISTRIBUTED_BACKEND` was deprecated in v1.6"
51
- ):
52
- assert strategy ._get_process_group_backend () == expected_process_group_backend
53
- else :
54
- assert strategy ._get_process_group_backend () == expected_process_group_backend
42
+ assert strategy ._get_process_group_backend () == expected_process_group_backend
0 commit comments