Skip to content

Commit 7f92d5c

Browse files
authored
[App] Refactor plugins to be a standalone LightningPlugin (#16765)
1 parent ac5fa03 commit 7f92d5c

File tree

12 files changed

+397
-237
lines changed

12 files changed

+397
-237
lines changed

src/lightning/app/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from lightning.app.core.app import LightningApp # noqa: E402
3535
from lightning.app.core.flow import LightningFlow # noqa: E402
36+
from lightning.app.core.plugin import LightningPlugin # noqa: E402
3637
from lightning.app.core.work import LightningWork # noqa: E402
3738
from lightning.app.perf import pdb # noqa: E402
3839
from lightning.app.utilities.packaging.build_config import BuildConfig # noqa: E402
@@ -46,4 +47,4 @@
4647
_PACKAGE_ROOT = os.path.dirname(__file__)
4748
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_PACKAGE_ROOT))
4849

49-
__all__ = ["LightningApp", "LightningFlow", "LightningWork", "BuildConfig", "CloudCompute", "pdb"]
50+
__all__ = ["LightningApp", "LightningFlow", "LightningWork", "LightningPlugin", "BuildConfig", "CloudCompute", "pdb"]

src/lightning/app/core/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from lightning.app.core.app import LightningApp
22
from lightning.app.core.flow import LightningFlow
3+
from lightning.app.core.plugin import LightningPlugin
34
from lightning.app.core.work import LightningWork
45

5-
__all__ = ["LightningApp", "LightningFlow", "LightningWork"]
6+
__all__ = ["LightningApp", "LightningFlow", "LightningWork", "LightningPlugin"]

src/lightning/app/core/flow.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from deepdiff import DeepHash
2323

24-
from lightning.app.core.plugin import Plugin
2524
from lightning.app.core.work import LightningWork
2625
from lightning.app.frontend import Frontend
2726
from lightning.app.storage import Path
@@ -741,22 +740,6 @@ def configure_api(self):
741740
"""
742741
raise NotImplementedError
743742

744-
def configure_plugins(self) -> Optional[List[Dict[str, Plugin]]]:
745-
"""Configure the plugins of this LightningFlow.
746-
747-
Returns a list of dictionaries mapping a plugin name to a :class:`lightning_app.core.plugin.Plugin`.
748-
749-
.. code-block:: python
750-
751-
class Flow(LightningFlow):
752-
def __init__(self):
753-
super().__init__()
754-
755-
def configure_plugins(self):
756-
return [{"my_plugin_name": MyPlugin()}]
757-
"""
758-
pass
759-
760743
def state_dict(self):
761744
"""Returns the current flow state but not its children."""
762745
return {

src/lightning/app/core/plugin.py

Lines changed: 100 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
import tarfile
1516
import tempfile
1617
from pathlib import Path
17-
from typing import Any, Dict, Optional
18+
from typing import Dict, List, Optional
19+
from urllib.parse import urlparse
1820

1921
import requests
2022
import uvicorn
@@ -23,77 +25,39 @@
2325
from pydantic import BaseModel
2426

2527
from lightning.app.utilities.app_helpers import Logger
26-
from lightning.app.utilities.cloud import _get_project
2728
from lightning.app.utilities.component import _set_flow_context
2829
from lightning.app.utilities.enum import AppStage
29-
from lightning.app.utilities.network import LightningClient
30+
from lightning.app.utilities.load_app import _load_plugin_from_file
3031

3132
logger = Logger(__name__)
3233

3334

34-
class Plugin:
35-
"""A ``Plugin`` is a single-file Python class that can be executed within a cloudspace to perform actions."""
35+
class LightningPlugin:
36+
"""A ``LightningPlugin`` is a single-file Python class that can be executed within a cloudspace to perform
37+
actions."""
3638

3739
def __init__(self) -> None:
38-
self.app_url = None
40+
self.project_id = ""
41+
self.cloudspace_id = ""
42+
self.cluster_id = ""
3943

40-
def run(self, name: str, entrypoint: str) -> None:
41-
"""Override with the logic to execute on the client side."""
44+
def run(self, *args: str, **kwargs: str) -> None:
45+
"""Override with the logic to execute on the cloudspace."""
4246

43-
def run_app_command(self, command_name: str, config: Optional[BaseModel] = None) -> Dict[str, Any]:
44-
"""Run a command on the app associated with this plugin.
47+
def run_job(self, name: str, app_entrypoint: str, env_vars: Optional[Dict[str, str]] = None) -> None:
48+
"""Run a job in the cloudspace associated with this plugin.
4549
4650
Args:
47-
command_name: The name of the command to run.
48-
config: The command config or ``None`` if the command doesn't require configuration.
51+
name: The name of the job.
52+
app_entrypoint: The path of the file containing the app to run.
53+
env_vars: Additional env vars to set when running the app.
4954
"""
50-
if self.app_url is None:
51-
raise RuntimeError("The plugin must be set up before `run_app_command` can be called.")
52-
53-
command = command_name.replace(" ", "_")
54-
resp = requests.post(self.app_url + f"/command/{command}", data=config.json() if config else None)
55-
if resp.status_code != 200:
56-
try:
57-
detail = str(resp.json())
58-
except Exception:
59-
detail = "Internal Server Error"
60-
raise RuntimeError(f"Failed with status code {resp.status_code}. Detail: {detail}")
61-
62-
return resp.json()
63-
64-
def _setup(self, app_id: str) -> None:
65-
client = LightningClient()
66-
project_id = _get_project(client).project_id
67-
response = client.lightningapp_instance_service_list_lightningapp_instances(
68-
project_id=project_id, app_id=app_id
69-
)
70-
if len(response.lightningapps) > 1:
71-
raise RuntimeError(f"Found multiple apps with ID: {app_id}")
72-
if len(response.lightningapps) == 0:
73-
raise RuntimeError(f"Found no apps with ID: {app_id}")
74-
self.app_url = response.lightningapps[0].status.url
75-
76-
77-
class _Run(BaseModel):
78-
plugin_name: str
79-
project_id: str
80-
cloudspace_id: str
81-
name: str
82-
entrypoint: str
83-
cluster_id: Optional[str] = None
84-
app_id: Optional[str] = None
85-
86-
87-
def _run_plugin(run: _Run) -> None:
88-
"""Create a run with the given name and entrypoint under the cloudspace with the given ID."""
89-
if run.app_id is None and run.plugin_name == "app":
9055
from lightning.app.runners.cloud import CloudRuntime
9156

92-
# TODO: App dispatch should be a plugin
93-
# Dispatch the run
57+
# Dispatch the job
9458
_set_flow_context()
9559

96-
entrypoint_file = Path("/content") / run.entrypoint
60+
entrypoint_file = Path(app_entrypoint)
9761

9862
app = CloudRuntime.load_app_from_file(str(entrypoint_file.resolve().absolute()))
9963

@@ -103,54 +67,101 @@ def _run_plugin(run: _Run) -> None:
10367
app=app,
10468
entrypoint=entrypoint_file,
10569
start_server=True,
106-
env_vars={},
70+
env_vars=env_vars if env_vars is not None else {},
10771
secrets={},
10872
run_app_comment_commands=True,
10973
)
11074
# Used to indicate Lightning has been dispatched
11175
os.environ["LIGHTNING_DISPATCHED"] = "1"
11276

77+
runtime.cloudspace_dispatch(
78+
project_id=self.project_id,
79+
cloudspace_id=self.cloudspace_id,
80+
name=name,
81+
cluster_id=self.cluster_id,
82+
)
83+
84+
def _setup(
85+
self,
86+
project_id: str,
87+
cloudspace_id: str,
88+
cluster_id: str,
89+
) -> None:
90+
self.project_id = project_id
91+
self.cloudspace_id = cloudspace_id
92+
self.cluster_id = cluster_id
93+
94+
95+
class _Run(BaseModel):
96+
plugin_entrypoint: str
97+
source_code_url: str
98+
project_id: str
99+
cloudspace_id: str
100+
cluster_id: str
101+
plugin_arguments: Dict[str, str]
102+
103+
104+
def _run_plugin(run: _Run) -> List:
105+
"""Create a run with the given name and entrypoint under the cloudspace with the given ID."""
106+
with tempfile.TemporaryDirectory() as tmpdir:
107+
download_path = os.path.join(tmpdir, "source.tar.gz")
108+
source_path = os.path.join(tmpdir, "source")
109+
os.makedirs(source_path)
110+
111+
# Download the tarball
113112
try:
114-
runtime.cloudspace_dispatch(
115-
project_id=run.project_id,
116-
cloudspace_id=run.cloudspace_id,
117-
name=run.name,
118-
cluster_id=run.cluster_id,
119-
)
113+
# Sometimes the URL gets encoded, so we parse it here
114+
source_code_url = urlparse(run.source_code_url).geturl()
115+
116+
response = requests.get(source_code_url)
117+
118+
with open(download_path, "wb") as f:
119+
f.write(response.content)
120120
except Exception as e:
121-
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
122-
elif run.app_id is not None:
123-
from lightning.app.utilities.cli_helpers import _LightningAppOpenAPIRetriever
124-
from lightning.app.utilities.commands.base import _download_command
121+
raise HTTPException(
122+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
123+
detail=f"Error downloading plugin source: {str(e)}.",
124+
)
125125

126-
retriever = _LightningAppOpenAPIRetriever(run.app_id)
126+
# Extract
127+
try:
128+
with tarfile.open(download_path, "r:gz") as tf:
129+
tf.extractall(source_path)
130+
except Exception as e:
131+
raise HTTPException(
132+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
133+
detail=f"Error extracting plugin source: {str(e)}.",
134+
)
127135

128-
metadata = retriever.api_commands[run.plugin_name] # type: ignore
136+
# Import the plugin
137+
try:
138+
plugin = _load_plugin_from_file(os.path.join(source_path, run.plugin_entrypoint))
139+
except Exception as e:
140+
raise HTTPException(
141+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error loading plugin: {str(e)}."
142+
)
129143

130-
with tempfile.TemporaryDirectory() as tmpdir:
144+
# Ensure that apps are dispatched from the temp directory
145+
cwd = os.getcwd()
146+
os.chdir(source_path)
131147

132-
target_file = os.path.join(tmpdir, f"{run.plugin_name}.py")
133-
plugin = _download_command(
134-
run.plugin_name,
135-
metadata["cls_path"],
136-
metadata["cls_name"],
137-
run.app_id,
138-
target_file=target_file,
148+
# Setup and run the plugin
149+
try:
150+
plugin._setup(
151+
project_id=run.project_id,
152+
cloudspace_id=run.cloudspace_id,
153+
cluster_id=run.cluster_id,
139154
)
155+
plugin.run(**run.plugin_arguments)
156+
except Exception as e:
157+
raise HTTPException(
158+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(e)}."
159+
)
160+
finally:
161+
os.chdir(cwd)
140162

141-
if isinstance(plugin, Plugin):
142-
plugin._setup(app_id=run.app_id)
143-
plugin.run(run.name, run.entrypoint)
144-
else:
145-
# This should never be possible but we check just in case
146-
raise HTTPException(
147-
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
148-
detail=f"The plugin {run.plugin_name} is an incorrect type.",
149-
)
150-
else:
151-
raise HTTPException(
152-
status_code=status.HTTP_400_BAD_REQUEST, detail="App ID must be specified unless `plugin_name='app'`."
153-
)
163+
# TODO: Return actions from the plugin here
164+
return []
154165

155166

156167
def _start_plugin_server(host: str, port: int) -> None:

src/lightning/app/runners/cloud.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def open(self, name: str, cluster_id: Optional[str] = None):
143143
ignore_functions = self._resolve_open_ignore_functions()
144144
repo = self._resolve_repo(root, ignore_functions)
145145
project = self._resolve_project()
146-
existing_cloudspaces = self._resolve_existing_cloudspaces(project, cloudspace_config.name)
146+
existing_cloudspaces = self._resolve_existing_cloudspaces(project.project_id, cloudspace_config.name)
147147
cluster_id = self._resolve_cluster_id(cluster_id, project.project_id, existing_cloudspaces)
148148
existing_cloudspace, existing_run_instance = self._resolve_existing_run_instance(
149149
cluster_id, project.project_id, existing_cloudspaces
@@ -213,7 +213,7 @@ def cloudspace_dispatch(
213213
project_id: str,
214214
cloudspace_id: str,
215215
name: str,
216-
cluster_id: str = None,
216+
cluster_id: str,
217217
):
218218
"""Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties
219219
such as the project and cluster IDs that are instead passed directly.
@@ -232,10 +232,10 @@ def cloudspace_dispatch(
232232
# Dispatch in four phases: resolution, validation, spec creation, API transactions
233233
# Resolution
234234
root = self._resolve_root()
235-
ignore_functions = self._resolve_open_ignore_functions()
236-
repo = self._resolve_repo(root, ignore_functions)
237-
cloudspace = self._resolve_cloudspace(project_id, cloudspace_id)
238-
cluster_id = self._resolve_cluster_id(cluster_id, project_id, [cloudspace])
235+
repo = self._resolve_repo(root)
236+
self._resolve_cloudspace(project_id, cloudspace_id)
237+
existing_instances = self._resolve_run_instances_by_name(project_id, name)
238+
name = self._resolve_run_name(name, existing_instances)
239239
queue_server_type = self._resolve_queue_server_type()
240240

241241
self.app._update_index_file()
@@ -294,7 +294,7 @@ def dispatch(
294294
root = self._resolve_root()
295295
repo = self._resolve_repo(root)
296296
project = self._resolve_project()
297-
existing_cloudspaces = self._resolve_existing_cloudspaces(project, cloudspace_config.name)
297+
existing_cloudspaces = self._resolve_existing_cloudspaces(project.project_id, cloudspace_config.name)
298298
cluster_id = self._resolve_cluster_id(cluster_id, project.project_id, existing_cloudspaces)
299299
existing_cloudspace, existing_run_instance = self._resolve_existing_run_instance(
300300
cluster_id, project.project_id, existing_cloudspaces
@@ -478,11 +478,11 @@ def _resolve_cloudspace(self, project_id: str, cloudspace_id: str) -> V1CloudSpa
478478
id=cloudspace_id,
479479
)
480480

481-
def _resolve_existing_cloudspaces(self, project, cloudspace_name: str) -> List[V1CloudSpace]:
481+
def _resolve_existing_cloudspaces(self, project_id: str, cloudspace_name: str) -> List[V1CloudSpace]:
482482
"""Lists all the cloudspaces with a name matching the provided cloudspace name."""
483483
# TODO: Add pagination, otherwise this could break if users have a lot of cloudspaces.
484484
existing_cloudspaces = self.backend.client.cloud_space_service_list_cloud_spaces(
485-
project_id=project.project_id
485+
project_id=project_id
486486
).cloudspaces
487487

488488
# Search for cloudspaces with the given name (possibly with some random characters appended)
@@ -521,6 +521,14 @@ def _resolve_existing_run_instance(
521521
break
522522
return existing_cloudspace, existing_run_instance
523523

524+
def _resolve_run_instances_by_name(self, project_id: str, name: str) -> List[Externalv1LightningappInstance]:
525+
"""Get all existing instances in the given project with the given name."""
526+
run_instances = self.backend.client.lightningapp_instance_service_list_lightningapp_instances(
527+
project_id=project_id,
528+
).lightningapps
529+
530+
return [run_instance for run_instance in run_instances if run_instance.display_name == name]
531+
524532
def _resolve_cloudspace_name(
525533
self,
526534
cloudspace_name: str,
@@ -529,16 +537,29 @@ def _resolve_cloudspace_name(
529537
) -> str:
530538
"""If there are existing cloudspaces but not on the cluster - choose a randomised name."""
531539
if len(existing_cloudspaces) > 0 and existing_cloudspace is None:
532-
letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
533-
534540
name_exists = True
535541
while name_exists:
536-
random_name = cloudspace_name + "-" + "".join(random.sample(letters, 4))
542+
random_name = cloudspace_name + "-" + "".join(random.sample(string.ascii_letters, 4))
537543
name_exists = any([app.name == random_name for app in existing_cloudspaces])
538544

539545
cloudspace_name = random_name
540546
return cloudspace_name
541547

548+
def _resolve_run_name(
549+
self,
550+
name: str,
551+
existing_instances: List[Externalv1LightningappInstance],
552+
) -> str:
553+
"""If there are existing instances with the same name - choose a randomised name."""
554+
if len(existing_instances) > 0:
555+
name_exists = True
556+
while name_exists:
557+
random_name = name + "-" + "".join(random.sample(string.ascii_letters, 4))
558+
name_exists = any([app.name == random_name for app in existing_instances])
559+
560+
name = random_name
561+
return name
562+
542563
def _resolve_queue_server_type(self) -> V1QueueServerType:
543564
"""Resolve the cloud queue type from the environment."""
544565
queue_server_type = V1QueueServerType.UNSPECIFIED

0 commit comments

Comments
 (0)