2
2
import inspect
3
3
import time
4
4
from copy import deepcopy
5
+ from dataclasses import dataclass
5
6
from functools import wraps
6
7
from multiprocessing import Queue
7
8
from typing import Any , Callable , Dict , List , Optional
8
9
from uuid import uuid4
9
10
10
- from fastapi import FastAPI , HTTPException
11
+ from fastapi import FastAPI , HTTPException , Request
12
+ from lightning_utilities .core .apply_func import apply_to_collection
11
13
12
14
from lightning_app .api .request_types import _APIRequest , _CommandRequest , _RequestResponse
13
15
from lightning_app .utilities .app_helpers import Logger
@@ -19,6 +21,77 @@ def _signature_proxy_function():
19
21
pass
20
22
21
23
24
+ @dataclass
25
+ class _FastApiMockRequest :
26
+ """This class is meant to mock FastAPI Request class that isn't pickle-able.
27
+
28
+ If a user relies on FastAPI Request annotation, the Lightning framework
29
+ patches the annotation before pickling and replace them right after.
30
+
31
+ Finally, the FastAPI request is converted back to the _FastApiMockRequest
32
+ before being delivered to the users.
33
+
34
+ Example:
35
+
36
+ import lightning as L
37
+ from fastapi import Request
38
+ from lightning.app.api import Post
39
+
40
+ class Flow(L.LightningFlow):
41
+
42
+ def request(self, request: Request) -> OutputRequestModel:
43
+ ...
44
+
45
+ def configure_api(self):
46
+ return [Post("/api/v1/request", self.request)]
47
+ """
48
+
49
+ _body : Optional [str ] = None
50
+ _json : Optional [str ] = None
51
+ _method : Optional [str ] = None
52
+ _headers : Optional [Dict ] = None
53
+
54
+ @property
55
+ def receive (self ):
56
+ raise NotImplementedError
57
+
58
+ @property
59
+ def method (self ):
60
+ raise self ._method
61
+
62
+ @property
63
+ def headers (self ):
64
+ return self ._headers
65
+
66
+ def body (self ):
67
+ return self ._body
68
+
69
+ def json (self ):
70
+ return self ._json
71
+
72
+ def stream (self ):
73
+ raise NotImplementedError
74
+
75
+ def form (self ):
76
+ raise NotImplementedError
77
+
78
+ def close (self ):
79
+ raise NotImplementedError
80
+
81
+ def is_disconnected (self ):
82
+ raise NotImplementedError
83
+
84
+
85
+ async def _mock_fastapi_request (request : Request ):
86
+ # TODO: Add more requests parameters.
87
+ return _FastApiMockRequest (
88
+ _body = await request .body (),
89
+ _json = await request .json (),
90
+ _headers = request .headers ,
91
+ _method = request .method ,
92
+ )
93
+
94
+
22
95
class _HttpMethod :
23
96
def __init__ (self , route : str , method : Callable , method_name : Optional [str ] = None , timeout : int = 30 , ** kwargs ):
24
97
"""This class is used to inject user defined methods within the App Rest API.
@@ -34,6 +107,7 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No
34
107
self .method_annotations = method .__annotations__
35
108
# TODO: Validate the signature contains only pydantic models.
36
109
self .method_signature = inspect .signature (method )
110
+
37
111
if not self .attached_to_flow :
38
112
self .component_name = method .__name__
39
113
self .method = method
@@ -43,10 +117,16 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No
43
117
self .timeout = timeout
44
118
self .kwargs = kwargs
45
119
120
+ # Enable the users to rely on FastAPI annotation typing with Request.
121
+ # Note: Only a part of the Request functionatilities are supported.
122
+ self ._patch_fast_api_request ()
123
+
46
124
def add_route (self , app : FastAPI , request_queue : Queue , responses_store : Dict [str , Any ]) -> None :
47
125
# 1: Get the route associated with the http method.
48
126
route = getattr (app , self .__class__ .__name__ .lower ())
49
127
128
+ self ._unpatch_fast_api_request ()
129
+
50
130
# 2: Create a proxy function with the signature of the wrapped method.
51
131
fn = deepcopy (_signature_proxy_function )
52
132
fn .__annotations__ = self .method_annotations
@@ -69,6 +149,11 @@ async def _handle_request(*args, **kwargs):
69
149
@wraps (_signature_proxy_function )
70
150
async def _handle_request (* args , ** kwargs ):
71
151
async def fn (* args , ** kwargs ):
152
+ args , kwargs = apply_to_collection ((args , kwargs ), Request , _mock_fastapi_request )
153
+ for k , v in kwargs .items ():
154
+ if hasattr (v , "__await__" ):
155
+ kwargs [k ] = await v
156
+
72
157
request_id = str (uuid4 ()).split ("-" )[0 ]
73
158
logger .debug (f"Processing request { request_id } for route: { self .route } " )
74
159
request_queue .put (
@@ -101,6 +186,26 @@ async def fn(*args, **kwargs):
101
186
# 4: Register the user provided route to the Rest API.
102
187
route (self .route , ** self .kwargs )(_handle_request )
103
188
189
+ def _patch_fast_api_request (self ):
190
+ """This function replaces signature annotation for Request with its mock."""
191
+ for k , v in self .method_annotations .items ():
192
+ if v == Request :
193
+ self .method_annotations [k ] = _FastApiMockRequest
194
+
195
+ for v in self .method_signature .parameters .values ():
196
+ if v ._annotation == Request :
197
+ v ._annotation = _FastApiMockRequest
198
+
199
+ def _unpatch_fast_api_request (self ):
200
+ """This function replaces back signature annotation to fastapi Request."""
201
+ for k , v in self .method_annotations .items ():
202
+ if v == _FastApiMockRequest :
203
+ self .method_annotations [k ] = Request
204
+
205
+ for v in self .method_signature .parameters .values ():
206
+ if v ._annotation == _FastApiMockRequest :
207
+ v ._annotation = Request
208
+
104
209
105
210
class Post (_HttpMethod ):
106
211
pass
0 commit comments