
Commit bb45d36

McPatategante and Joao Gante authored
refactor(serve): move request_id to headers (#40722)
* refactor(serve): move `request_id` to headers
* fix(serve): typo in middleware fn name

Co-authored-by: Joao Gante <[email protected]>
1 parent 12b8e10 commit bb45d36
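
With this change, the streaming request id is carried in the standard `X-Request-ID` HTTP header instead of a `request_id` field in the JSON body, and the server echoes it back on the response. Below is a minimal client-side sketch of the new contract (not part of this commit; the base URL and prompt are placeholder assumptions, the model name is taken from the test suite):

import requests

# Assumed local `transformers serve` endpoint; adjust host/port to your setup.
base_url = "http://localhost:8000"

with requests.post(
    f"{base_url}/v1/chat/completions",
    headers={"X-Request-ID": "my-request-1"},  # previously sent as "request_id" in the JSON body
    json={
        "model": "Qwen/Qwen2.5-0.5B-Instruct",
        "stream": True,
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    stream=True,
    timeout=30,
) as resp:
    resp.raise_for_status()
    # The new middleware copies the id onto the response headers.
    assert resp.headers["x-request-id"] == "my-request-1"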

File tree: 2 files changed, +50 -29 lines


src/transformers/commands/serving.py

Lines changed: 24 additions & 15 deletions
@@ -24,6 +24,7 @@
 import tempfile
 import threading
 import time
+import uuid
 from argparse import ArgumentParser, Namespace
 from collections.abc import AsyncGenerator, Generator, Iterable
 from contextlib import asynccontextmanager
@@ -132,7 +133,6 @@ class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreamin
     """

     generation_config: str
-    request_id: str

 class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
     """
@@ -211,6 +211,8 @@ class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total
 }
 _MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys())

+X_REQUEST_ID = "x-request-id"
+

 class Modality(enum.Enum):
     LLM = "LLM"
@@ -688,14 +690,16 @@ async def lifespan(app: FastAPI):
                 "CORS allow origin is set to `*`. This is not recommended for production environments."
             )

+        from fastapi import Request
+
         @app.post("/v1/chat/completions")
-        def chat_completion(request: dict):
-            self.validate_chat_completion_request(request=request)
+        def chat_completion(request: Request, body: dict):
+            self.validate_chat_completion_request(request=body)

             if self.use_continuous_batching:
-                output = self.continuous_batching_chat_completion(request)
+                output = self.continuous_batching_chat_completion(body, request.state.request_id)
             else:
-                output = self.generate_chat_completion(request)
+                output = self.generate_chat_completion(body)
             return StreamingResponse(output, media_type="text/event-stream")

         @app.post("/v1/responses")
@@ -705,8 +709,6 @@ def responses(request: dict):
             output = self.generate_response(request)
             return StreamingResponse(output, media_type="text/event-stream")

-        from fastapi import Request
-
         @app.post("/v1/audio/transcriptions")
         async def audio_transcriptions(request: Request):
             # Parses the multipart/form-data request into the request format used by other endpoints
@@ -734,6 +736,14 @@ def get_all_models():
         def healthcheck():
             return JSONResponse({"status": "ok"})

+        @app.middleware("http")
+        async def get_or_set_request_id(request: Request, call_next):
+            request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
+            request.state.request_id = request_id
+            response = await call_next(request)
+            response.headers[X_REQUEST_ID] = request_id
+            return response
+
         uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)

     @functools.cache
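
Not part of the diff: the middleware added above can be exercised on its own. A minimal standalone sketch (hypothetical app and route, not the serve command's; requires `fastapi` and `httpx` for the test client):

import uuid

from fastapi import FastAPI, Request
from fastapi.testclient import TestClient

X_REQUEST_ID = "x-request-id"

app = FastAPI()


@app.middleware("http")
async def get_or_set_request_id(request: Request, call_next):
    # Reuse the client-supplied id or mint a new UUID, then echo it back on the response.
    request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
    request.state.request_id = request_id
    response = await call_next(request)
    response.headers[X_REQUEST_ID] = request_id
    return response


@app.get("/ping")
def ping(request: Request):
    return {"request_id": request.state.request_id}


client = TestClient(app)
assert client.get("/ping", headers={"X-Request-ID": "abc"}).headers[X_REQUEST_ID] == "abc"
assert client.get("/ping").headers[X_REQUEST_ID]  # a generated UUID when the header is absent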
@@ -782,7 +792,7 @@ def get_gen_models(self) -> list[dict[str, any]]:
             for model in model_infos
         ]

-    def continuous_batching_chat_completion(self, req: dict) -> AsyncGenerator[str, None]:
+    def continuous_batching_chat_completion(self, req: dict, request_id: str) -> AsyncGenerator[str, None]:
         """
         Generates an OpenAI Chat Completion using continuous batching.

@@ -858,22 +868,21 @@ def stream_chat_completion(request_id, decode_stream):
                 self.running_continuous_batching_manager.cancel_request(request_id)
                 yield f'data: {{"error": "{str(e)}"}}'

-        async def cancellation_wrapper(_inputs):
-            request_id = None
+        async def cancellation_wrapper(_inputs, request_id):
             try:
                 decode_stream = DecodeStream(_inputs.tolist(), False)
+                # XXX: using returned request_id as safety in case it is None
                 request_id = self.running_continuous_batching_manager.add_request(
-                    _inputs, request_id=req.get("request_id"), max_new_tokens=generation_config.max_new_tokens
+                    _inputs, request_id=request_id, max_new_tokens=generation_config.max_new_tokens
                 )
                 for chunk in stream_chat_completion(request_id, decode_stream):
                     yield chunk
                     await asyncio.sleep(0)  # Yield control to the event loop to check for cancellations
             except asyncio.CancelledError:
-                if request_id is not None:
-                    self.running_continuous_batching_manager.cancel_request(request_id)
-                    logger.warning(f"Request {request_id} was cancelled.")
+                self.running_continuous_batching_manager.cancel_request(request_id)
+                logger.warning(f"Request {request_id} was cancelled.")

-        return cancellation_wrapper(inputs[0])
+        return cancellation_wrapper(inputs[0], request_id)

     @staticmethod
     def get_model_modality(model: "PreTrainedModel") -> Modality:

tests/commands/test_serving.py

Lines changed: 26 additions & 14 deletions
@@ -498,30 +498,45 @@ def _get_scheduler(serve_command):
     cbm = getattr(serve_command, "running_continuous_batching_manager", None)
     assert cbm is not None, "ServeCommand has no running_continuous_batching_manager"
     bp = getattr(cbm, "batch_processor", None)
-    assert bp is not None, "CBM has no batch_processor"
+    assert bp is not None, "running_continuous_batching_manager has no batch_processor"
     sched = getattr(bp, "scheduler", None)
     assert sched is not None, "batch_processor has no scheduler"
     return sched


+def _call_healthcheck(base_url: str):
+    response = None
+    retries = 10
+    while retries > 0:
+        try:
+            response = requests.get(f"{base_url}/health")
+            break
+        except requests.exceptions.ConnectionError:
+            time.sleep(0.1)
+            retries -= 1
+    return response
+
+
 def _open_stream_and_cancel(base_url: str, request_id: str):
     with requests.Session() as s:
         with s.post(
             f"{base_url}/v1/chat/completions",
+            headers={"X-Request-ID": request_id},
             json={
                 "model": "Qwen/Qwen2.5-0.5B-Instruct",
                 "stream": True,
                 "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}],
-                "request_id": request_id,
             },
             stream=True,
             timeout=30,
         ) as resp:
             assert resp.status_code == 200

-            for _ in resp.iter_content(chunk_size=None):
-                resp.close()
-                break
+            wait_for_n_chunks = 3
+            for i, _ in enumerate(resp.iter_content(chunk_size=None)):
+                if i >= wait_for_n_chunks:
+                    resp.close()
+                    break


 @slow # server startup time is slow on our push CI
@@ -598,6 +613,11 @@ def test_request_cancellation(self):
         base_url = f"http://127.0.0.1:{self.port}"
         request_id = "test-cancel"

+        # Ensure the server is up before sending a request
+        response = _call_healthcheck(base_url)
+        self.assertIsNotNone(response, "Failed to connect to the server health endpoint.")
+        self.assertEqual(response.status_code, 200)
+
         _open_stream_and_cancel(base_url, request_id)

         scheduler = _get_scheduler(self.serve_command)
@@ -724,15 +744,7 @@ def setUpClass(cls):

     def test_healthcheck(self):
         """Tests that the healthcheck endpoint works."""
-        response = None
-        retries = 10
-        while retries > 0:
-            try:
-                response = requests.get(f"http://localhost:{self.port}/health")
-                break
-            except requests.exceptions.ConnectionError:
-                time.sleep(0.1)
-                retries -= 1
+        response = _call_healthcheck(f"http://localhost:{self.port}")
         self.assertIsNotNone(response, "Failed to connect to the server health endpoint.")
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.json(), {"status": "ok"})