Skip to content
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/app/base.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
lightning-cloud>=0.5.24
lightning-cloud>=0.5.26
packaging
typing-extensions>=4.0.0, <=4.4.0
deepdiff>=5.7.0, <6.2.4
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Added `lightning connect data` to register data connection to private s3 buckets ([#16738](https://github.com/Lightning-AI/lightning/pull/16738))


### Changed
Expand Down
5 changes: 2 additions & 3 deletions src/lightning/app/cli/commands/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import click
import rich
from fastapi import HTTPException
from lightning_cloud.openapi import Externalv1LightningappInstance
from rich.console import Console
from rich.live import Live
Expand Down Expand Up @@ -65,7 +64,7 @@ def ls(path: Optional[str] = None, print: bool = True, use_live: bool = True) ->
lines = f.readlines()
root = lines[0].replace("\n", "")

client = LightningClient()
client = LightningClient(retry=False)
projects = client.projects_service_list_memberships()

if root == "/":
Expand Down Expand Up @@ -256,7 +255,7 @@ def _collect_artifacts(
page_token=response.next_page_token,
tokens=tokens,
)
except HTTPException:
except Exception:
# Note: This is triggered when the request is wrong.
# This is currently happening due to looping through the user clusters.
pass
Expand Down
45 changes: 31 additions & 14 deletions src/lightning/app/cli/connect/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import sys

import click
import lightning_cloud
import rich
from rich.live import Live
from rich.spinner import Spinner
Expand All @@ -29,20 +31,32 @@


@click.argument("name", required=True)
@click.argument("region", required=True)
@click.argument("source", required=True)
@click.argument("destination", required=False)
@click.argument("project_name", required=False)
@click.option("--region", help="The AWS region of your bucket. Example: `us-west-1`.", required=True)
@click.option(
"--source", help="The URL path to your AWS S3 folder. Example: `s3://pl-flash-data/images/`.", required=True
)
@click.option(
"--secret_arn_name",
help="The name of role stored as a secret on Lightning AI to access your data. "
"Learn more with https://gist.github.com/tchaton/12ad4b788012e83c0eb35e6223ae09fc. "
"Example: `my_role`.",
required=False,
)
@click.option(
"--destination", help="Where your data should appear in the cloud. Currently not supported.", required=False
)
@click.option("--project_name", help="The project name on which to create the data connection.", required=False)
def connect_data(
name: str,
region: str,
source: str,
secret_arn_name: str = "",
destination: str = "",
project_name: str = "",
) -> None:
"""Create a new data connection."""

from lightning_cloud.openapi import ProjectIdDataConnectionsBody
from lightning_cloud.openapi import Create, V1AwsDataConnection

if sys.platform == "win32":
_error_and_exit("Data connection isn't supported on windows. Open an issue on Github.")
Expand All @@ -51,7 +65,7 @@ def connect_data(

live.stop()

client = LightningClient()
client = LightningClient(retry=False)
projects = client.projects_service_list_memberships()

project_id = None
Expand All @@ -71,12 +85,15 @@ def connect_data(
)

try:
_ = client.data_connection_service_create_data_connection(
body=ProjectIdDataConnectionsBody(
client.data_connection_service_create_data_connection(
body=Create(
name=name,
region=region,
source=source,
destination=destination,
aws=V1AwsDataConnection(
region=region,
source=source,
destination=destination,
secret_arn_name=secret_arn_name,
),
),
project_id=project_id,
)
Expand All @@ -86,8 +103,8 @@ def connect_data(
# project_id=project_id,
# id=response.id,
# )
# print(response)
except Exception:
_error_and_exit("The data connection creation failed.")
except lightning_cloud.openapi.rest.ApiException as e:
message = ast.literal_eval(e.body.decode("utf-8"))["message"]
_error_and_exit(f"The data connection creation failed. Message: {message}")

rich.print(f"[green]Succeeded[/green]: You have created a new data connection {name}.")
13 changes: 6 additions & 7 deletions tests/tests_app/cli/test_connect_data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import sys
from unittest.mock import MagicMock

import pytest

from lightning.app.cli.connect import data


@pytest.mark.skipif(True, reason="In progress")
@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows")
def test_connect_data_no_project(monkeypatch):

from lightning_cloud.openapi import V1ListMembershipsResponse, V1Membership
Expand All @@ -26,10 +27,10 @@ def test_connect_data_no_project(monkeypatch):
_get_project.assert_called()


@pytest.mark.skipif(True, reason="In progress")
@pytest.mark.skipif(sys.platform == "win32", reason="lightning connect data isn't supported on windows")
def test_connect_data(monkeypatch):

from lightning_cloud.openapi import ProjectIdDataConnectionsBody, V1ListMembershipsResponse, V1Membership
from lightning_cloud.openapi import Create, V1AwsDataConnection, V1ListMembershipsResponse, V1Membership

client = MagicMock()
client.projects_service_list_memberships.return_value = V1ListMembershipsResponse(
Expand All @@ -53,10 +54,8 @@ def test_connect_data(monkeypatch):

client.data_connection_service_create_data_connection.assert_called_with(
project_id="project-id-0",
body=ProjectIdDataConnectionsBody(
destination="",
region="us-east-1",
body=Create(
name="imagenet",
source="s3://imagenet",
aws=V1AwsDataConnection(destination="", region="us-east-1", source="s3://imagenet", secret_arn_name=""),
),
)