Skip to content

Commit d21b899

Browse files
authored
Fix multinode cloud component (#15965)
* fix multinode cloud component * add tests
1 parent c748f82 commit d21b899

File tree

5 files changed

+45
-2
lines changed

5 files changed

+45
-2
lines changed

src/lightning_app/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1515

1616
- Changed the default port of `PythonServer` from `7777` to a free port at runtime ([#15966](https://github.com/Lightning-AI/lightning/pull/15966))
1717

18+
1819
- Remove the `AutoScaler` dependency `aiohttp` from the base requirements ([#15971](https://github.com/Lightning-AI/lightning/pull/15971))
1920

2021

@@ -30,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3031

3132
### Fixed
3233

34+
- Fixed MultiNode Component to use separate cloud computes ([#15965](https://github.com/Lightning-AI/lightning/pull/15965))
35+
36+
3337
- Fixed `AutoScaler` failing due to port collision across works ([#15966](https://github.com/Lightning-AI/lightning/pull/15966))
3438

3539

@@ -80,6 +84,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8084
- Fixed detection of a Lightning App running in debug mode ([#15951](https://github.com/Lightning-AI/lightning/pull/15951))
8185
- Fixed `ImportError` on Multinode if package not present ([#15963](https://github.com/Lightning-AI/lightning/pull/15963))
8286

87+
8388
## [1.8.3] - 2022-11-22
8489

8590
### Changed

src/lightning_app/components/multi_node/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def run(
6666
*[
6767
work_cls(
6868
*work_args,
69-
cloud_compute=cloud_compute,
69+
cloud_compute=cloud_compute.clone(),
7070
**work_kwargs,
7171
parallel=True,
7272
)

src/lightning_app/utilities/packaging/cloud_compute.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __post_init__(self) -> None:
8282

8383
# All `default` CloudCompute are identified in the same way.
8484
if self._internal_id is None:
85-
self._internal_id = "default" if self.name == "default" else uuid4().hex[:7]
85+
self._internal_id = self._generate_id()
8686

8787
# Internal arguments for now.
8888
self.preemptible = False
@@ -118,6 +118,14 @@ def id(self) -> Optional[str]:
118118
def is_default(self) -> bool:
119119
return self.name == "default"
120120

121+
def _generate_id(self):
122+
return "default" if self.name == "default" else uuid4().hex[:7]
123+
124+
def clone(self):
125+
new_dict = self.to_dict()
126+
new_dict["_internal_id"] = self._generate_id()
127+
return self.from_dict(new_dict)
128+
121129

122130
def _verify_mount_root_dirs_are_unique(mounts: Union[None, Mount, List[Mount], Tuple[Mount]]) -> None:
123131
if isinstance(mounts, (list, tuple, set)):

tests/tests_app/components/multi_node/test_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from re import escape
2+
from unittest import mock
23

34
import pytest
45
from tests_app.helpers.utils import no_warning_call
@@ -17,3 +18,14 @@ def run(self):
1718

1819
with no_warning_call(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
1920
MultiNode(Work, num_nodes=1, cloud_compute=CloudCompute("gpu"))
21+
22+
23+
@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", mock.Mock(return_value=True))
24+
def test_multi_node_separate_cloud_computes():
25+
class Work(LightningWork):
26+
def run(self):
27+
pass
28+
29+
m = MultiNode(Work, num_nodes=2, cloud_compute=CloudCompute("gpu"))
30+
31+
assert len({w.cloud_compute._internal_id for w in m.ws}) == len(m.ws)

tests/tests_app/utilities/packaging/test_cloud_compute.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,21 @@ def test_cloud_compute_with_non_unique_mount_root_dirs():
4141

4242
with pytest.raises(ValueError, match="Every Mount attached to a work must have a unique"):
4343
CloudCompute("gpu", mounts=[mount_1, mount_2])
44+
45+
46+
def test_cloud_compute_clone():
47+
c1 = CloudCompute("gpu")
48+
c2 = c1.clone()
49+
50+
assert isinstance(c2, CloudCompute)
51+
52+
c1_dict = c1.to_dict()
53+
c2_dict = c2.to_dict()
54+
55+
assert len(c1_dict) == len(c2_dict)
56+
57+
for k in c1_dict.keys():
58+
if k == "_internal_id":
59+
assert c1_dict[k] != c2_dict[k]
60+
else:
61+
assert c1_dict[k] == c2_dict[k]

0 commit comments

Comments
 (0)