Skip to content

Commit 6d92d7b

Browse files
Sherin ThomasBorda
authored andcommitted
[App] Connect and Disconnect node (#16700)
Connect and Disconnect node (cherry picked from commit 44557b9)
1 parent 921df6e commit 6d92d7b

File tree

3 files changed

+214
-9
lines changed

3 files changed

+214
-9
lines changed
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
# Copyright The Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import platform
15+
import shlex
16+
import subprocess
17+
import sys
18+
import time
19+
20+
import click
21+
import rich
22+
from rich.live import Live
23+
from rich.spinner import Spinner
24+
from rich.text import Text
25+
26+
from lightning_app.core.constants import get_lightning_cloud_url
27+
from lightning_app.utilities.app_helpers import Logger
28+
from lightning_app.utilities.cli_helpers import _error_and_exit
29+
30+
logger = Logger(__name__)
31+
32+
NETWORK_NAME = "lightning-byom"
33+
CMD_CREATE_NETWORK = f"docker network create {NETWORK_NAME}"
34+
35+
CODE_SERVER_CONTAINER = "code-server"
36+
CODE_SERVER_IMAGE = "ghcr.io/gridai/lightning-byom-code-server:v0.1"
37+
CODE_SERVER_PORT = 8443
38+
39+
LIGHTNING_CLOUD_URL = get_lightning_cloud_url()
40+
ROOT_DOMAIN = LIGHTNING_CLOUD_URL.split("//")[1]
41+
CLOUD_PROXY_HOST = f"byom.{ROOT_DOMAIN}"
42+
43+
LIGHTNING_DAEMON_NODE_PREFIX = ""
44+
LIGHTNING_DAEMON_CONTAINER = "lightning-daemon"
45+
LIGHTNING_DAEMON_IMAGE = "ghcr.io/gridai/lightning-daemon:v0.1"
46+
47+
48+
def get_code_server_docker_command() -> str:
49+
return (
50+
f"docker run "
51+
f"-p {CODE_SERVER_PORT}:{CODE_SERVER_PORT} "
52+
f"--net {NETWORK_NAME} "
53+
f"--name {CODE_SERVER_CONTAINER} "
54+
f"--rm {CODE_SERVER_IMAGE}"
55+
)
56+
57+
58+
def get_lightning_daemon_command(node_prefix: str) -> str:
59+
return (
60+
f"docker run "
61+
f"-e LIGHTNING_BYOM_CLOUD_PROXY_HOST=https://{node_prefix}.{CLOUD_PROXY_HOST} "
62+
f"-e LIGHTNING_BYOM_RESOURCE_URL=http://{CODE_SERVER_CONTAINER}:{CODE_SERVER_PORT} "
63+
f"--net {NETWORK_NAME} "
64+
f"--name {LIGHTNING_DAEMON_CONTAINER} "
65+
f"--rm {LIGHTNING_DAEMON_IMAGE}"
66+
)
67+
68+
69+
@click.argument("name", required=True)
70+
def connect_node(name: str) -> None:
71+
"""Create a new node connection."""
72+
# print system architecture and OS
73+
if sys.platform != "darwin" or platform.processor() != "arm":
74+
_error_and_exit("Node connection is only supported from M1 Macs at the moment")
75+
76+
# check if docker client is installed or not
77+
try:
78+
subprocess.run("docker --version", shell=True, check=True, capture_output=True)
79+
except subprocess.CalledProcessError:
80+
_error_and_exit("Docker client is not installed. Please install docker and try again.")
81+
82+
# check if docker daemon is running or not
83+
try:
84+
subprocess.run("docker ps", shell=True, check=True, capture_output=True)
85+
except subprocess.CalledProcessError:
86+
_error_and_exit("Docker daemon is not running. Please start docker and try again.")
87+
88+
if "lightning.ai" in CLOUD_PROXY_HOST:
89+
_error_and_exit("Node connection isn't publicly available. Open an issue on Github.")
90+
91+
with Live(Spinner("point", text=Text("pending...", style="white")), transient=True) as live:
92+
# run network creation in the background
93+
out = subprocess.run(CMD_CREATE_NETWORK, shell=True, capture_output=True)
94+
error = out.stderr
95+
if error:
96+
if "already exists" not in str(error):
97+
live.stop()
98+
rich.print(f"[red]Failed[/red]: network creation failed with error: {str(error)}")
99+
return
100+
101+
# if code server is already running, ignore.
102+
# If not, but container exists, remove it and run. Otherwise, run.
103+
out = subprocess.run(
104+
f"docker ps -q -f name={CODE_SERVER_CONTAINER}", shell=True, check=True, capture_output=True
105+
)
106+
if out.stdout:
107+
pass
108+
else:
109+
out = subprocess.run(
110+
f"docker container ls -aq -f name={CODE_SERVER_CONTAINER}", shell=True, check=True, capture_output=True
111+
)
112+
if out.stdout:
113+
subprocess.run(f"docker rm -f {CODE_SERVER_CONTAINER}", shell=True, check=True)
114+
else:
115+
live.update(Spinner("point", text=Text("pulling code server image", style="white")))
116+
out = subprocess.run(f"docker pull {CODE_SERVER_IMAGE}", shell=True, check=True, capture_output=True)
117+
error = out.stderr
118+
if error:
119+
live.stop()
120+
rich.print(f"[red]Failed[/red]: code server image pull failed with error: {str(error)}")
121+
return
122+
cmd = get_code_server_docker_command()
123+
live.update(Spinner("point", text=Text("running code server", style="white")))
124+
_ = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
125+
126+
# if lightning daemon is already running, ignore.
127+
# If not, but container exists, remove it and run. Otherwise, run.
128+
out = subprocess.run(
129+
f"docker ps -q -f name={LIGHTNING_DAEMON_CONTAINER}", shell=True, check=True, capture_output=True
130+
)
131+
if out.stdout:
132+
pass
133+
else:
134+
out = subprocess.run(
135+
f"docker container ls -aq -f name={LIGHTNING_DAEMON_CONTAINER}",
136+
shell=True,
137+
check=True,
138+
capture_output=True,
139+
)
140+
if out.stdout:
141+
subprocess.run(f"docker rm -f {LIGHTNING_DAEMON_CONTAINER}", shell=True, check=True)
142+
else:
143+
live.update(Spinner("point", text=Text("pulling lightning daemon image", style="white")))
144+
out = subprocess.run(
145+
f"docker pull {LIGHTNING_DAEMON_IMAGE}", shell=True, check=True, capture_output=True
146+
)
147+
error = out.stderr
148+
if error:
149+
live.stop()
150+
rich.print(f"[red]Failed[/red]: lightnign daemon image pull failed with error: {str(error)}")
151+
return
152+
cmd = get_lightning_daemon_command(name)
153+
live.update(Spinner("point", text=Text("running lightning daemon", style="white")))
154+
_ = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
155+
156+
# wait until if both docker containers are running
157+
code_server_running = False
158+
lightning_daemon_running = False
159+
live.update(Spinner("point", text=Text("establishing connection ...", style="white")))
160+
connection_check_start_time = time.time()
161+
162+
# wait for 30 seconds for connection to be established
163+
while time.time() - connection_check_start_time < 30:
164+
out = subprocess.run(
165+
f"docker container ls -f name={CODE_SERVER_CONTAINER} " + '--format "{{.Status}}"',
166+
shell=True,
167+
check=True,
168+
capture_output=True,
169+
)
170+
171+
if "Up" in str(out.stdout):
172+
code_server_running = True
173+
174+
out = subprocess.run(
175+
f"docker container ls -f name={LIGHTNING_DAEMON_CONTAINER} " + '--format "{{.Status}}"',
176+
shell=True,
177+
check=True,
178+
capture_output=True,
179+
)
180+
if "Up" in str(out.stdout):
181+
lightning_daemon_running = True
182+
183+
if code_server_running and lightning_daemon_running:
184+
break
185+
186+
# Sleeping for 0.5 seconds
187+
time.sleep(0.5)
188+
rich.print(
189+
f"[green]Succeeded[/green]: node {name} has been connected to lightning. \n "
190+
f"Go to https://{name}.{CLOUD_PROXY_HOST} to access the node."
191+
)
192+
193+
194+
@click.argument("name", required=True)
195+
def disconnect_node(name: str) -> None:
196+
# disconnect node stop and remove the docker containers
197+
with Live(Spinner("point", text=Text("disconnecting node...", style="white")), transient=True):
198+
subprocess.run(f"docker stop {CODE_SERVER_CONTAINER}", shell=True, capture_output=True)
199+
subprocess.run(f"docker stop {LIGHTNING_DAEMON_CONTAINER}", shell=True, capture_output=True)
200+
rich.print(f"[green]Succeeded[/green]: node {name} has been disconnected from lightning.")

src/lightning_app/cli/lightning_cli.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
disconnect_app,
4646
)
4747
from lightning_app.cli.connect.data import connect_data
48+
from lightning_app.cli.connect.node import connect_node, disconnect_node
4849
from lightning_app.cli.lightning_cli_create import create
4950
from lightning_app.cli.lightning_cli_delete import delete
5051
from lightning_app.cli.lightning_cli_list import get_list
@@ -136,11 +137,13 @@ def disconnect() -> None:
136137

137138
connect.command("app")(connect_app)
138139
disconnect.command("app")(disconnect_app)
139-
connect.command("data")(connect_data)
140-
_main.command()(ls)
141-
_main.command()(cd)
142-
_main.command()(cp)
143-
_main.command()(pwd)
140+
connect.command("node", hidden=True)(connect_node)
141+
disconnect.command("node", hidden=True)(disconnect_node)
142+
connect.command("data", hidden=True)(connect_data)
143+
_main.command(hidden=True)(ls)
144+
_main.command(hidden=True)(cd)
145+
_main.command(hidden=True)(cp)
146+
_main.command(hidden=True)(pwd)
144147
show.command()(logs)
145148

146149

tests/tests_app/cli/test_connect_data.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
import sys
21
from unittest.mock import MagicMock
32

43
import pytest
5-
from lightning_cloud.openapi import ProjectIdDataConnectionsBody, V1ListMembershipsResponse, V1Membership
64

75
from lightning_app.cli.connect import data
86

97

10-
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
8+
@pytest.mark.skipif(True, reason="In progress")
119
def test_connect_data_no_project(monkeypatch):
1210

11+
from lightning_cloud.openapi import V1ListMembershipsResponse, V1Membership
12+
1313
client = MagicMock()
1414
client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(memberships=[])
1515
monkeypatch.setattr(data, "LightningClient", MagicMock(return_value=client))
@@ -26,9 +26,11 @@ def test_connect_data_no_project(monkeypatch):
2626
_get_project.assert_called()
2727

2828

29-
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows yet")
29+
@pytest.mark.skipif(True, reason="In progress")
3030
def test_connect_data(monkeypatch):
3131

32+
from lightning_cloud.openapi import ProjectIdDataConnectionsBody, V1ListMembershipsResponse, V1Membership
33+
3234
client = MagicMock()
3335
client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
3436
memberships=[

0 commit comments

Comments
 (0)