Skip to content

Commit 7e8400d

Browse files
tchatonthomas
andauthored
[App] Add support for private data (#16738)
Co-authored-by: thomas <[email protected]>
1 parent 32e7137 commit 7e8400d

File tree

5 files changed

+41
-26
lines changed

5 files changed

+41
-26
lines changed

requirements/app/base.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
lightning-cloud>=0.5.24
1+
lightning-cloud>=0.5.26
22
packaging
33
typing-extensions>=4.0.0, <=4.4.0
44
deepdiff>=5.7.0, <6.2.4

src/lightning/app/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12-
-
12+
- Added `lightning connect data` to register data connection to private s3 buckets ([#16738](https://github.com/Lightning-AI/lightning/pull/16738))
1313

1414

1515
### Changed

src/lightning/app/cli/commands/ls.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
import click
2121
import rich
22-
from fastapi import HTTPException
2322
from lightning_cloud.openapi import Externalv1LightningappInstance
2423
from rich.console import Console
2524
from rich.live import Live
@@ -65,7 +64,7 @@ def ls(path: Optional[str] = None, print: bool = True, use_live: bool = True) ->
6564
lines = f.readlines()
6665
root = lines[0].replace("\n", "")
6766

68-
client = LightningClient()
67+
client = LightningClient(retry=False)
6968
projects = client.projects_service_list_memberships()
7069

7170
if root == "/":
@@ -256,7 +255,7 @@ def _collect_artifacts(
256255
page_token=response.next_page_token,
257256
tokens=tokens,
258257
)
259-
except HTTPException:
258+
except Exception:
260259
# Note: This is triggered when the request is wrong.
261260
# This is currently happening due to looping through the user clusters.
262261
pass

src/lightning/app/cli/connect/data.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import ast
1516
import sys
1617

1718
import click
19+
import lightning_cloud
1820
import rich
1921
from rich.live import Live
2022
from rich.spinner import Spinner
@@ -29,20 +31,32 @@
2931

3032

3133
@click.argument("name", required=True)
32-
@click.argument("region", required=True)
33-
@click.argument("source", required=True)
34-
@click.argument("destination", required=False)
35-
@click.argument("project_name", required=False)
34+
@click.option("--region", help="The AWS region of your bucket. Example: `us-west-1`.", required=True)
35+
@click.option(
36+
"--source", help="The URL path to your AWS S3 folder. Example: `s3://pl-flash-data/images/`.", required=True
37+
)
38+
@click.option(
39+
"--secret_arn_name",
40+
help="The name of role stored as a secret on Lightning AI to access your data. "
41+
"Learn more with https://gist.github.com/tchaton/12ad4b788012e83c0eb35e6223ae09fc. "
42+
"Example: `my_role`.",
43+
required=False,
44+
)
45+
@click.option(
46+
"--destination", help="Where your data should appear in the cloud. Currently not supported.", required=False
47+
)
48+
@click.option("--project_name", help="The project name on which to create the data connection.", required=False)
3649
def connect_data(
3750
name: str,
3851
region: str,
3952
source: str,
53+
secret_arn_name: str = "",
4054
destination: str = "",
4155
project_name: str = "",
4256
) -> None:
4357
"""Create a new data connection."""
4458

45-
from lightning_cloud.openapi import ProjectIdDataConnectionsBody
59+
from lightning_cloud.openapi import Create, V1AwsDataConnection
4660

4761
if sys.platform == "win32":
4862
_error_and_exit("Data connection isn't supported on windows. Open an issue on Github.")
@@ -51,7 +65,7 @@ def connect_data(
5165

5266
live.stop()
5367

54-
client = LightningClient()
68+
client = LightningClient(retry=False)
5569
projects = client.projects_service_list_memberships()
5670

5771
project_id = None
@@ -71,12 +85,15 @@ def connect_data(
7185
)
7286

7387
try:
74-
_ = client.data_connection_service_create_data_connection(
75-
body=ProjectIdDataConnectionsBody(
88+
client.data_connection_service_create_data_connection(
89+
body=Create(
7690
name=name,
77-
region=region,
78-
source=source,
79-
destination=destination,
91+
aws=V1AwsDataConnection(
92+
region=region,
93+
source=source,
94+
destination=destination,
95+
secret_arn_name=secret_arn_name,
96+
),
8097
),
8198
project_id=project_id,
8299
)
@@ -86,8 +103,8 @@ def connect_data(
86103
# project_id=project_id,
87104
# id=response.id,
88105
# )
89-
# print(response)
90-
except Exception:
91-
_error_and_exit("The data connection creation failed.")
106+
except lightning_cloud.openapi.rest.ApiException as e:
107+
message = ast.literal_eval(e.body.decode("utf-8"))["message"]
108+
_error_and_exit(f"The data connection creation failed. Message: {message}")
92109

93110
rich.print(f"[green]Succeeded[/green]: You have created a new data connection {name}.")

tests/tests_app/cli/test_connect_data.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
import sys
12
from unittest.mock import MagicMock
23

34
import pytest
45

56
from lightning.app.cli.connect import data
67

78

8-
@pytest.mark.skipif(True, reason="In progress")
9+
@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows")
910
def test_connect_data_no_project(monkeypatch):
1011

1112
from lightning_cloud.openapi import V1ListMembershipsResponse, V1Membership
@@ -26,10 +27,10 @@ def test_connect_data_no_project(monkeypatch):
2627
_get_project.assert_called()
2728

2829

29-
@pytest.mark.skipif(True, reason="In progress")
30+
@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows")
3031
def test_connect_data(monkeypatch):
3132

32-
from lightning_cloud.openapi import ProjectIdDataConnectionsBody, V1ListMembershipsResponse, V1Membership
33+
from lightning_cloud.openapi import Create, V1AwsDataConnection, V1ListMembershipsResponse, V1Membership
3334

3435
client = MagicMock()
3536
client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
@@ -53,10 +54,8 @@ def test_connect_data(monkeypatch):
5354

5455
client.data_connection_service_create_data_connection.assert_called_with(
5556
project_id="project-id-0",
56-
body=ProjectIdDataConnectionsBody(
57-
destination="",
58-
region="us-east-1",
57+
body=Create(
5958
name="imagenet",
60-
source="s3://imagenet",
59+
aws=V1AwsDataConnection(destination="", region="us-east-1", source="s3://imagenet", secret_arn_name=""),
6160
),
6261
)

0 commit comments

Comments
 (0)