22
22
23
23
import click
24
24
import requests
25
- import rich
26
25
import urllib3
27
- from lightning_cloud .openapi import Externalv1LightningappInstance , IdArtifactsBody , V1CloudSpace
26
+ from lightning_cloud .openapi import Externalv1LightningappInstance , ProjectIdStorageBody , V1CloudSpace
28
27
from rich .live import Live
29
28
from rich .progress import BarColumn , DownloadColumn , Progress , Task , TextColumn
30
29
from rich .spinner import Spinner
34
33
from lightning_app .cli .commands .pwd import _pwd
35
34
from lightning_app .source_code import FileUploader
36
35
from lightning_app .utilities .app_helpers import Logger
36
+ from lightning_app .utilities .cli_helpers import _error_and_exit
37
37
from lightning_app .utilities .network import LightningClient
38
38
39
39
logger = Logger (__name__ )
@@ -55,7 +55,7 @@ def cp(src_path: str, dst_path: str, r: bool = False, recursive: bool = False) -
55
55
pwd = _pwd ()
56
56
57
57
if pwd == "/" or len (pwd .split ("/" )) == 1 :
58
- return _error_and_exit ("Uploading files at the project level isn't supported yet." )
58
+ return _error_and_exit ("Uploading files at the project level isn't allowed yet." )
59
59
60
60
client = LightningClient ()
61
61
@@ -74,10 +74,18 @@ def cp(src_path: str, dst_path: str, r: bool = False, recursive: bool = False) -
74
74
75
75
76
76
def _upload_files (live , client : LightningClient , local_src : str , remote_dst : str , pwd : str ) -> str :
77
+ remote_splits = [split for split in remote_dst .split ("/" ) if split != "" ]
78
+ remote_dst = os .path .join (* remote_splits )
79
+
77
80
if not os .path .exists (local_src ):
78
81
return _error_and_exit (f"The provided source path { local_src } doesn't exist." )
79
82
80
- project_id , app_id = _get_project_app_ids (pwd )
83
+ lit_resource = None
84
+
85
+ if len (remote_splits ) > 1 :
86
+ project_id , lit_resource = _get_project_id_and_resource (pwd )
87
+ else :
88
+ project_id = _get_project_id_from_name (remote_dst )
81
89
82
90
local_src = Path (local_src ).resolve ()
83
91
upload_paths = []
@@ -91,21 +99,34 @@ def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str
91
99
92
100
upload_urls = []
93
101
102
+ clusters = client .projects_service_list_project_cluster_bindings (project_id )
103
+
94
104
for upload_path in upload_paths :
95
- filename = str (upload_path ).replace (str (os .getcwd ()), "" )[1 :]
96
- response = client .lightningapp_instance_service_upload_lightningapp_instance_artifact (
97
- project_id = project_id ,
98
- id = app_id ,
99
- body = IdArtifactsBody (filename ),
100
- async_req = True ,
101
- )
102
- upload_urls .append (response )
105
+ for cluster in clusters .clusters :
106
+ filename = str (upload_path ).replace (str (os .getcwd ()), "" )[1 :]
107
+ if lit_resource :
108
+ filename = _get_prefix (os .path .join (remote_dst , filename ), lit_resource )
109
+ else :
110
+ filename = "/" + filename
111
+
112
+ response = client .lightningapp_instance_service_upload_project_artifact (
113
+ project_id = project_id ,
114
+ body = ProjectIdStorageBody (cluster_id = cluster .cluster_id , filename = filename ),
115
+ async_req = True ,
116
+ )
117
+ upload_urls .append (response )
118
+
119
+ upload_urls = [upload_url .get ().upload_url for upload_url in upload_urls ]
103
120
104
121
live .stop ()
105
122
123
+ if not upload_paths :
124
+ print ("There were no files to upload." )
125
+ return
126
+
106
127
progress = _get_progress_bar ()
107
128
108
- total_size = sum ([Path (path ).stat ().st_size for path in upload_paths ])
129
+ total_size = sum ([Path (path ).stat ().st_size for path in upload_paths ]) // max ( len ( clusters . clusters ), 1 )
109
130
task_id = progress .add_task ("upload" , filename = "" , total = total_size )
110
131
111
132
progress .start ()
@@ -126,7 +147,7 @@ def _upload_files(live, client: LightningClient, local_src: str, remote_dst: str
126
147
def _upload (source_file : str , presigned_url : ApplyResult , progress : Progress , task_id : Task ) -> Optional [Exception ]:
127
148
source_file = Path (source_file )
128
149
file_uploader = FileUploader (
129
- presigned_url . get (). upload_url ,
150
+ presigned_url ,
130
151
source_file ,
131
152
total_size = None ,
132
153
name = str (source_file ),
@@ -143,13 +164,13 @@ def _download_files(live, client, remote_src: str, local_dst: str, pwd: str):
143
164
download_urls = []
144
165
total_size = []
145
166
146
- prefix = _get_prefix ("/" .join (pwd .split ("/" )[3 :]), lit_resource )
167
+ prefix = _get_prefix ("/" .join (pwd .split ("/" )[3 :]), lit_resource ) + "/"
147
168
148
169
for artifact in _collect_artifacts (client , project_id , prefix , include_download_url = True ):
149
170
path = os .path .join (local_dst , artifact .filename .replace (remote_src , "" ))
150
171
path = Path (path ).resolve ()
151
172
os .makedirs (path .parent , exist_ok = True )
152
- download_paths .append (Path ( path ). resolve () )
173
+ download_paths .append (path )
153
174
download_urls .append (artifact .url )
154
175
total_size .append (int (artifact .size_bytes ))
155
176
@@ -182,14 +203,17 @@ def _download_file(path: str, url: str, progress: Progress, task_id: Task) -> No
182
203
# Disable warning about making an insecure request
183
204
urllib3 .disable_warnings (urllib3 .exceptions .InsecureRequestWarning )
184
205
185
- request = requests .get (url , stream = True , verify = False )
206
+ try :
207
+ request = requests .get (url , stream = True , verify = False )
186
208
187
- chunk_size = 1024
209
+ chunk_size = 1024
188
210
189
- with open (path , "wb" ) as fp :
190
- for chunk in request .iter_content (chunk_size = chunk_size ):
191
- fp .write (chunk ) # type: ignore
192
- progress .update (task_id , advance = len (chunk ))
211
+ with open (path , "wb" ) as fp :
212
+ for chunk in request .iter_content (chunk_size = chunk_size ):
213
+ fp .write (chunk ) # type: ignore
214
+ progress .update (task_id , advance = len (chunk ))
215
+ except ConnectionError :
216
+ pass
193
217
194
218
195
219
def _sanitize_path (path : str , pwd : str ) -> Tuple [str , bool ]:
@@ -211,29 +235,6 @@ def _remove_remote(path: str) -> str:
211
235
return path .replace ("r:" , "" ).replace ("remote:" , "" )
212
236
213
237
214
- def _error_and_exit (msg : str ) -> str :
215
- rich .print (f"[red]ERROR[/red]: { msg } " )
216
- sys .exit (0 )
217
-
218
-
219
- # TODO: To be removed when upload is supported for CloudSpaces.
220
- def _get_project_app_ids (pwd : str ) -> Tuple [str , str ]:
221
- """Convert a root path to a project id and app id."""
222
- # TODO: Handle project level
223
- project_name , app_name , * _ = pwd .split ("/" )[1 :3 ]
224
- client = LightningClient ()
225
- projects = client .projects_service_list_memberships ()
226
- project_id = [project .project_id for project in projects .memberships if project .name == project_name ][0 ]
227
- client = LightningClient ()
228
- lit_apps = client .lightningapp_instance_service_list_lightningapp_instances (project_id = project_id ).lightningapps
229
- lit_apps = [lit_app for lit_app in lit_apps if lit_app .name == app_name ]
230
- if len (lit_apps ) != 1 :
231
- print (f"ERROR: There isn't any Lightning App matching the name { app_name } ." )
232
- sys .exit (0 )
233
- lit_app = lit_apps [0 ]
234
- return project_id , lit_app .id
235
-
236
-
237
238
def _get_project_id_and_resource (pwd : str ) -> Tuple [str , Union [Externalv1LightningappInstance , V1CloudSpace ]]:
238
239
"""Convert a root path to a project id and app id."""
239
240
# TODO: Handle project level
@@ -263,6 +264,13 @@ def _get_project_id_and_resource(pwd: str) -> Tuple[str, Union[Externalv1Lightni
263
264
return project_id , lit_ressources [0 ]
264
265
265
266
267
+ def _get_project_id_from_name (project_name : str ) -> str :
268
+ # 1. Collect the projects of the user
269
+ client = LightningClient ()
270
+ projects = client .projects_service_list_memberships ()
271
+ return [project .project_id for project in projects .memberships if project .name == project_name ][0 ]
272
+
273
+
266
274
def _get_progress_bar ():
267
275
return Progress (
268
276
TextColumn ("[bold blue]{task.description}" , justify = "left" ),
0 commit comments