Skip to content

Commit 592b126

Browse files
authored
[App] PoC: Add support for Request (#16047)
1 parent 005b6f2 commit 592b126

File tree

4 files changed

+116
-5
lines changed

4 files changed

+116
-5
lines changed

.gitignore

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ celerybeat-schedule
110110

111111
# dotenv
112112
.env
113-
.env_staging
114-
.env_local
113+
.env.staging
114+
.env.local
115115

116116
# virtualenv
117117
.venv

src/lightning_app/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1313
- Added a progres bar while connecting to an app through the CLI ([#16035](https://github.com/Lightning-AI/lightning/pull/16035))
1414

1515

16+
- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047))
17+
18+
1619
### Changed
1720

1821
-

src/lightning_app/api/http_methods.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
import inspect
33
import time
44
from copy import deepcopy
5+
from dataclasses import dataclass
56
from functools import wraps
67
from multiprocessing import Queue
78
from typing import Any, Callable, Dict, List, Optional
89
from uuid import uuid4
910

10-
from fastapi import FastAPI, HTTPException
11+
from fastapi import FastAPI, HTTPException, Request
12+
from lightning_utilities.core.apply_func import apply_to_collection
1113

1214
from lightning_app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse
1315
from lightning_app.utilities.app_helpers import Logger
@@ -19,6 +21,77 @@ def _signature_proxy_function():
1921
pass
2022

2123

24+
@dataclass
25+
class _FastApiMockRequest:
26+
"""This class is meant to mock FastAPI Request class that isn't pickle-able.
27+
28+
If a user relies on FastAPI Request annotation, the Lightning framework
29+
patches the annotation before pickling and replace them right after.
30+
31+
Finally, the FastAPI request is converted back to the _FastApiMockRequest
32+
before being delivered to the users.
33+
34+
Example:
35+
36+
import lightning as L
37+
from fastapi import Request
38+
from lightning.app.api import Post
39+
40+
class Flow(L.LightningFlow):
41+
42+
def request(self, request: Request) -> OutputRequestModel:
43+
...
44+
45+
def configure_api(self):
46+
return [Post("/api/v1/request", self.request)]
47+
"""
48+
49+
_body: Optional[str] = None
50+
_json: Optional[str] = None
51+
_method: Optional[str] = None
52+
_headers: Optional[Dict] = None
53+
54+
@property
55+
def receive(self):
56+
raise NotImplementedError
57+
58+
@property
59+
def method(self):
60+
raise self._method
61+
62+
@property
63+
def headers(self):
64+
return self._headers
65+
66+
def body(self):
67+
return self._body
68+
69+
def json(self):
70+
return self._json
71+
72+
def stream(self):
73+
raise NotImplementedError
74+
75+
def form(self):
76+
raise NotImplementedError
77+
78+
def close(self):
79+
raise NotImplementedError
80+
81+
def is_disconnected(self):
82+
raise NotImplementedError
83+
84+
85+
async def _mock_fastapi_request(request: Request):
86+
# TODO: Add more requests parameters.
87+
return _FastApiMockRequest(
88+
_body=await request.body(),
89+
_json=await request.json(),
90+
_headers=request.headers,
91+
_method=request.method,
92+
)
93+
94+
2295
class _HttpMethod:
2396
def __init__(self, route: str, method: Callable, method_name: Optional[str] = None, timeout: int = 30, **kwargs):
2497
"""This class is used to inject user defined methods within the App Rest API.
@@ -34,6 +107,7 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No
34107
self.method_annotations = method.__annotations__
35108
# TODO: Validate the signature contains only pydantic models.
36109
self.method_signature = inspect.signature(method)
110+
37111
if not self.attached_to_flow:
38112
self.component_name = method.__name__
39113
self.method = method
@@ -43,10 +117,16 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No
43117
self.timeout = timeout
44118
self.kwargs = kwargs
45119

120+
# Enable the users to rely on FastAPI annotation typing with Request.
121+
# Note: Only a part of the Request functionatilities are supported.
122+
self._patch_fast_api_request()
123+
46124
def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None:
47125
# 1: Get the route associated with the http method.
48126
route = getattr(app, self.__class__.__name__.lower())
49127

128+
self._unpatch_fast_api_request()
129+
50130
# 2: Create a proxy function with the signature of the wrapped method.
51131
fn = deepcopy(_signature_proxy_function)
52132
fn.__annotations__ = self.method_annotations
@@ -69,6 +149,11 @@ async def _handle_request(*args, **kwargs):
69149
@wraps(_signature_proxy_function)
70150
async def _handle_request(*args, **kwargs):
71151
async def fn(*args, **kwargs):
152+
args, kwargs = apply_to_collection((args, kwargs), Request, _mock_fastapi_request)
153+
for k, v in kwargs.items():
154+
if hasattr(v, "__await__"):
155+
kwargs[k] = await v
156+
72157
request_id = str(uuid4()).split("-")[0]
73158
logger.debug(f"Processing request {request_id} for route: {self.route}")
74159
request_queue.put(
@@ -101,6 +186,26 @@ async def fn(*args, **kwargs):
101186
# 4: Register the user provided route to the Rest API.
102187
route(self.route, **self.kwargs)(_handle_request)
103188

189+
def _patch_fast_api_request(self):
190+
"""This function replaces signature annotation for Request with its mock."""
191+
for k, v in self.method_annotations.items():
192+
if v == Request:
193+
self.method_annotations[k] = _FastApiMockRequest
194+
195+
for v in self.method_signature.parameters.values():
196+
if v._annotation == Request:
197+
v._annotation = _FastApiMockRequest
198+
199+
def _unpatch_fast_api_request(self):
200+
"""This function replaces back signature annotation to fastapi Request."""
201+
for k, v in self.method_annotations.items():
202+
if v == _FastApiMockRequest:
203+
self.method_annotations[k] = Request
204+
205+
for v in self.method_signature.parameters.values():
206+
if v._annotation == _FastApiMockRequest:
207+
v._annotation = Request
208+
104209

105210
class Post(_HttpMethod):
106211
pass

tests/tests_app/core/test_lightning_api.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pytest
1313
import requests
1414
from deepdiff import DeepDiff, Delta
15-
from fastapi import HTTPException
15+
from fastapi import HTTPException, Request
1616
from httpx import AsyncClient
1717
from pydantic import BaseModel
1818

@@ -479,10 +479,13 @@ def run(self):
479479
if self.counter == 501:
480480
self._exit()
481481

482-
def request(self, config: InputRequestModel) -> OutputRequestModel:
482+
def request(self, config: InputRequestModel, request: Request) -> OutputRequestModel:
483483
self.counter += 1
484484
if config.index % 5 == 0:
485485
raise HTTPException(status_code=400, detail="HERE")
486+
assert request.body()
487+
assert request.json()
488+
assert request.headers
486489
return OutputRequestModel(name=config.name, counter=self.counter)
487490

488491
def configure_api(self):

0 commit comments

Comments
 (0)