 import tempfile
 import threading
 import time
+import uuid
 from argparse import ArgumentParser, Namespace
 from collections.abc import AsyncGenerator, Generator, Iterable
 from contextlib import asynccontextmanager
@@ -132,7 +133,6 @@ class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreamin
     """

     generation_config: str
-    request_id: str


 class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
     """
@@ -211,6 +211,8 @@ class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total
 }
 _MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys())

+X_REQUEST_ID = "x-request-id"
+

 class Modality(enum.Enum):
     LLM = "LLM"
@@ -688,14 +690,16 @@ async def lifespan(app: FastAPI):
                 "CORS allow origin is set to `*`. This is not recommended for production environments."
             )

+        from fastapi import Request
+
         @app.post("/v1/chat/completions")
-        def chat_completion(request: dict):
-            self.validate_chat_completion_request(request=request)
+        def chat_completion(request: Request, body: dict):
+            self.validate_chat_completion_request(request=body)

             if self.use_continuous_batching:
-                output = self.continuous_batching_chat_completion(request)
+                output = self.continuous_batching_chat_completion(body, request.state.request_id)
             else:
-                output = self.generate_chat_completion(request)
+                output = self.generate_chat_completion(body)
             return StreamingResponse(output, media_type="text/event-stream")

         @app.post("/v1/responses")
@@ -705,8 +709,6 @@ def responses(request: dict):
             output = self.generate_response(request)
             return StreamingResponse(output, media_type="text/event-stream")

-        from fastapi import Request
-
         @app.post("/v1/audio/transcriptions")
         async def audio_transcriptions(request: Request):
             # Parses the multipart/form-data request into the request format used by other endpoints
@@ -734,6 +736,14 @@ def get_all_models():
         def healthcheck():
             return JSONResponse({"status": "ok"})

+        @app.middleware("http")
+        async def get_or_set_request_id(request: Request, call_next):
+            request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
+            request.state.request_id = request_id
+            response = await call_next(request)
+            response.headers[X_REQUEST_ID] = request_id
+            return response
+
         uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)

    @functools.cache
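
The middleware added above is self-contained enough to exercise in isolation. The sketch below reproduces the same header round-trip with a throwaway FastAPI app; the `/ping` route and the test client are illustrative stand-ins, not part of this patch.

```python
import uuid

from fastapi import FastAPI, Request
from fastapi.testclient import TestClient

X_REQUEST_ID = "x-request-id"

app = FastAPI()


@app.middleware("http")
async def get_or_set_request_id(request: Request, call_next):
    # Reuse the caller-supplied id if present, otherwise mint a fresh one.
    request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
    request.state.request_id = request_id
    response = await call_next(request)
    # Echo the id back so clients can correlate logs and responses.
    response.headers[X_REQUEST_ID] = request_id
    return response


@app.get("/ping")  # illustrative route, not part of the PR
def ping(request: Request):
    return {"request_id": request.state.request_id}


client = TestClient(app)
assert client.get("/ping", headers={X_REQUEST_ID: "abc-123"}).headers[X_REQUEST_ID] == "abc-123"
assert client.get("/ping").headers[X_REQUEST_ID]  # generated when the header is absent
```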
@@ -782,7 +792,7 @@ def get_gen_models(self) -> list[dict[str, any]]:
             for model in model_infos
         ]

-    def continuous_batching_chat_completion(self, req: dict) -> AsyncGenerator[str, None]:
+    def continuous_batching_chat_completion(self, req: dict, request_id: str) -> AsyncGenerator[str, None]:
         """
         Generates an OpenAI Chat Completion using continuous batching.

@@ -858,22 +868,21 @@ def stream_chat_completion(request_id, decode_stream):
                 self.running_continuous_batching_manager.cancel_request(request_id)
                 yield f'data: {{"error": "{str(e)}"}}'

-        async def cancellation_wrapper(_inputs):
-            request_id = None
+        async def cancellation_wrapper(_inputs, request_id):
             try:
                 decode_stream = DecodeStream(_inputs.tolist(), False)
+                # XXX: using returned request_id as safety in case it is None
                 request_id = self.running_continuous_batching_manager.add_request(
-                    _inputs, request_id=req.get("request_id"), max_new_tokens=generation_config.max_new_tokens
+                    _inputs, request_id=request_id, max_new_tokens=generation_config.max_new_tokens
                 )
                 for chunk in stream_chat_completion(request_id, decode_stream):
                     yield chunk
                     await asyncio.sleep(0)  # Yield control to the event loop to check for cancellations
             except asyncio.CancelledError:
-                if request_id is not None:
-                    self.running_continuous_batching_manager.cancel_request(request_id)
-                    logger.warning(f"Request {request_id} was cancelled.")
+                self.running_continuous_batching_manager.cancel_request(request_id)
+                logger.warning(f"Request {request_id} was cancelled.")

-        return cancellation_wrapper(inputs[0])
+        return cancellation_wrapper(inputs[0], request_id)

     @staticmethod
     def get_model_modality(model: "PreTrainedModel") -> Modality:
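
From the client side, the net effect of the patch is that `/v1/chat/completions` now echoes (or mints) an `x-request-id` header, and the same id names the request inside the continuous-batching manager. A minimal usage sketch, assuming a server listening on localhost:8000 and an OpenAI-style chat payload; the URL, model name, and message below are placeholders, not values taken from this patch.

```python
import requests

# Hypothetical check of the x-request-id round trip; the route comes from the diff,
# the payload fields are assumed from the OpenAI-compatible request schema.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"x-request-id": "my-trace-id"},
    json={
        "model": "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model name
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
)
print(resp.headers.get("x-request-id"))  # "my-trace-id", echoed by the middleware
for line in resp.iter_lines():
    if line:
        print(line.decode())  # server-sent "data: ..." chunks
```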