4
4
import pathlib
5
5
from dataclasses import dataclass
6
6
from enum import Enum
7
+ from time import sleep
7
8
from typing import Optional
8
9
from urllib .parse import urlencode
9
10
@@ -44,9 +45,11 @@ def __post_init__(self):
44
45
setattr (self , key .suffix , os .environ .get (key .value , None ))
45
46
46
47
self ._with_env_var = bool (self .user_id and self .api_key ) # used by authenticate method
47
- if self .api_key and not self .user_id :
48
+ if self ._with_env_var :
49
+ self .save ("" , self .user_id , self .api_key , self .user_id )
50
+ logger .info ("Credentials loaded from environment variables" )
51
+ elif self .api_key or self .user_id :
48
52
raise ValueError (
49
- f"{ Keys .USER_ID .value } is missing from env variables. "
50
53
"To use env vars for authentication both "
51
54
f"{ Keys .USER_ID .value } and { Keys .API_KEY .value } should be set."
52
55
)
@@ -135,7 +138,8 @@ def authenticate(self) -> Optional[str]:
135
138
136
139
137
140
class AuthServer :
138
- def get_auth_url (self , port : int ) -> str :
141
+ @staticmethod
142
+ def get_auth_url (port : int ) -> str :
139
143
redirect_uri = f"http://localhost:{ port } /login-complete"
140
144
params = urlencode (dict (redirectTo = redirect_uri ))
141
145
return f"{ get_lightning_cloud_url ()} /sign-in?{ params } "
@@ -144,6 +148,7 @@ def login_with_browser(self, auth: Auth) -> None:
144
148
app = FastAPI ()
145
149
port = find_free_network_port ()
146
150
url = self .get_auth_url (port )
151
+
147
152
try :
148
153
# check if server is reachable or catch any network errors
149
154
requests .head (url )
@@ -156,32 +161,42 @@ def login_with_browser(self, auth: Auth) -> None:
156
161
f"An error occurred with the request. Please report this issue to Lightning Team \n { e } " # E501
157
162
)
158
163
159
- logger .info (f"login started for lightning.ai, opening { url } " )
164
+ logger .info (
165
+ "\n Attempting to automatically open the login page in your default browser.\n "
166
+ 'If the browser does not open, navigate to the "Keys" tab on your Lightning AI profile page:\n \n '
167
+ f"{ get_lightning_cloud_url ()} /me/keys\n \n "
168
+ 'Copy the "Headless CLI Login" command, and execute it in your terminal.\n '
169
+ )
160
170
click .launch (url )
161
171
162
172
@app .get ("/login-complete" )
163
173
async def save_token (request : Request , token = "" , key = "" , user_id : str = Query ("" , alias = "userID" )):
164
- if token :
165
- auth .save (token = token , username = user_id , user_id = user_id , api_key = key )
166
- logger .info ("Authentication Successful" )
167
- else :
174
+ async def stop_server_once_request_is_done ():
175
+ while not await request .is_disconnected ():
176
+ sleep (0.25 )
177
+ server .should_exit = True
178
+
179
+ if not token :
168
180
logger .warn (
169
- "Authentication Failed. This is most likely because you're using an older version of the CLI. \n " # noqa E501
181
+ "Login Failed. This is most likely because you're using an older version of the CLI. \n " # noqa E501
170
182
"Please try to update the CLI or open an issue with this information \n " # E501
171
183
f"expected token in { request .query_params .items ()} "
172
184
)
185
+ return RedirectResponse (
186
+ url = f"{ get_lightning_cloud_url ()} /cli-login-failed" ,
187
+ background = BackgroundTask (stop_server_once_request_is_done ),
188
+ )
189
+
190
+ auth .save (token = token , username = user_id , user_id = user_id , api_key = key )
191
+ logger .info ("Login Successful" )
173
192
174
193
# Include the credentials in the redirect so that UI will also be logged in
175
194
params = urlencode (dict (token = token , key = key , userID = user_id ))
176
195
177
196
return RedirectResponse (
178
- url = f"{ get_lightning_cloud_url ()} /me/apps?{ params } " ,
179
- # The response background task is being executed right after the server finished writing the response
180
- background = BackgroundTask (stop_server ),
197
+ url = f"{ get_lightning_cloud_url ()} /cli-login-successful?{ params } " ,
198
+ background = BackgroundTask (stop_server_once_request_is_done ),
181
199
)
182
200
183
- def stop_server ():
184
- server .should_exit = True
185
-
186
201
server = uvicorn .Server (config = uvicorn .Config (app , port = port , log_level = "error" ))
187
202
server .run ()
0 commit comments