Skip to content

Commit 9cea90e

Browse files
authored
[Frontend] Add /classify endpoint (#17032)
Signed-off-by: Frieda (Jingying) Huang <[email protected]>
1 parent d1110f5 commit 9cea90e

File tree

9 files changed

+961
-162
lines changed

9 files changed

+961
-162
lines changed

docs/source/models/pooling_models.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ Our [OpenAI-Compatible Server](#openai-compatible-server) provides endpoints tha
140140

141141
- [Pooling API](#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models.
142142
- [Embeddings API](#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](#multimodal-inputs) for embedding models.
143+
- [Classification API](#classification-api) is similar to `LLM.classify` and is applicable to sequence classification models.
143144
- [Score API](#score-api) is similar to `LLM.score` for cross-encoder models.
144145

145146
## Matryoshka Embeddings

docs/source/serving/openai_compatible_server.md

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ In addition, we have the following custom APIs:
6161
- Applicable to any model with a tokenizer.
6262
- [Pooling API](#pooling-api) (`/pooling`)
6363
- Applicable to all [pooling models](../models/pooling_models.md).
64+
- [Classification API](#classification-api) (`/classify`)
65+
- Only applicable to [classification models](../models/pooling_models.md) (`--task classify`).
6466
- [Score API](#score-api) (`/score`)
6567
- Applicable to embedding models and [cross-encoder models](../models/pooling_models.md) (`--task score`).
6668
- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
@@ -443,6 +445,130 @@ The input format is the same as [Embeddings API](#embeddings-api), but the outpu
443445

444446
Code example: <gh-file:examples/online_serving/openai_pooling_client.py>
445447

448+
(classification-api)=
449+
450+
### Classification API
451+
452+
Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach).
453+
454+
We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
455+
456+
Code example: <gh-file:examples/online_serving/openai_classification_client.py>
457+
458+
#### Example Requests
459+
460+
You can classify multiple texts by passing an array of strings:
461+
462+
Request:
463+
464+
```bash
465+
curl -v "http://127.0.0.1:8000/classify" \
466+
-H "Content-Type: application/json" \
467+
-d '{
468+
"model": "jason9693/Qwen2.5-1.5B-apeach",
469+
"input": [
470+
"Loved the new café—coffee was great.",
471+
"This update broke everything. Frustrating."
472+
]
473+
}'
474+
```
475+
476+
Response:
477+
478+
```json
479+
{
480+
"id": "classify-7c87cac407b749a6935d8c7ce2a8fba2",
481+
"object": "list",
482+
"created": 1745383065,
483+
"model": "jason9693/Qwen2.5-1.5B-apeach",
484+
"data": [
485+
{
486+
"index": 0,
487+
"label": "Default",
488+
"probs": [
489+
0.565970778465271,
490+
0.4340292513370514
491+
],
492+
"num_classes": 2
493+
},
494+
{
495+
"index": 1,
496+
"label": "Spoiled",
497+
"probs": [
498+
0.26448777318000793,
499+
0.7355121970176697
500+
],
501+
"num_classes": 2
502+
}
503+
],
504+
"usage": {
505+
"prompt_tokens": 20,
506+
"total_tokens": 20,
507+
"completion_tokens": 0,
508+
"prompt_tokens_details": null
509+
}
510+
}
511+
```
512+
513+
You can also pass a string directly to the `input` field:
514+
515+
Request:
516+
517+
```bash
518+
curl -v "http://127.0.0.1:8000/classify" \
519+
-H "Content-Type: application/json" \
520+
-d '{
521+
"model": "jason9693/Qwen2.5-1.5B-apeach",
522+
"input": "Loved the new café—coffee was great."
523+
}'
524+
```
525+
526+
Response:
527+
528+
```json
529+
{
530+
"id": "classify-9bf17f2847b046c7b2d5495f4b4f9682",
531+
"object": "list",
532+
"created": 1745383213,
533+
"model": "jason9693/Qwen2.5-1.5B-apeach",
534+
"data": [
535+
{
536+
"index": 0,
537+
"label": "Default",
538+
"probs": [
539+
0.565970778465271,
540+
0.4340292513370514
541+
],
542+
"num_classes": 2
543+
}
544+
],
545+
"usage": {
546+
"prompt_tokens": 10,
547+
"total_tokens": 10,
548+
"completion_tokens": 0,
549+
"prompt_tokens_details": null
550+
}
551+
}
552+
```
553+
554+
#### Extra parameters
555+
556+
The following [pooling parameters](#pooling-params) are supported.
557+
558+
:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
559+
:language: python
560+
:start-after: begin-classification-pooling-params
561+
:end-before: end-classification-pooling-params
562+
:::
563+
564+
The following extra parameters are supported:
565+
566+
:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
567+
:language: python
568+
:start-after: begin-classification-extra-params
569+
:end-before: end-classification-extra-params
570+
:::
571+
446572
(score-api)=
447573

448574
### Score API
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import argparse
4+
import pprint
5+
6+
import requests
7+
8+
9+
def post_http_request(payload: dict, api_url: str) -> requests.Response:
    """POST *payload* as JSON to *api_url* and return the raw response."""
    return requests.post(
        api_url,
        headers={"User-Agent": "Test Client"},
        json=payload,
    )
13+
14+
15+
def parse_args():
    """Build and parse the command-line options for this client script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--model",
        type=str,
        default="jason9693/Qwen2.5-1.5B-apeach",
    )
    return parser.parse_args()
23+
24+
25+
def main(args):
    """Send a batch of prompts to the /classify endpoint and print the reply.

    Args:
        args: parsed CLI namespace providing ``host``, ``port`` and ``model``.
    """
    api_url = f"http://{args.host}:{args.port}/classify"

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    payload = {
        "model": args.model,
        "input": prompts,
    }

    # Pretty-print the server's JSON classification response.
    response = post_http_request(payload=payload, api_url=api_url)
    pprint.pprint(response.json())
45+
46+
47+
if __name__ == "__main__":
    main(parse_args())
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import pytest
4+
import requests
5+
6+
from vllm.entrypoints.openai.protocol import ClassificationResponse
7+
8+
from ...utils import RemoteOpenAIServer
9+
10+
# Sequence-classification model exercised by every test in this module.
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
11+
DTYPE = "float32" # Use float32 to avoid NaN issue
12+
13+
14+
@pytest.fixture(scope="module")
def server():
    """Launch one vLLM OpenAI-compatible server shared by the whole module."""
    cli_args = [
        "--enforce-eager",
        "--max-model-len",
        "512",
        "--dtype",
        DTYPE,
    ]
    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server
26+
27+
28+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_single_input_classification(server: RemoteOpenAIServer,
                                     model_name: str):
    """A single string input yields exactly one labeled classification."""
    response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": "This product was excellent and exceeded my expectations",
        },
    )
    response.raise_for_status()

    output = ClassificationResponse.model_validate(response.json())

    assert output.object == "list"
    assert output.model == MODEL_NAME
    assert len(output.data) == 1
    assert hasattr(output.data[0], "label")
    assert hasattr(output.data[0], "probs")
50+
51+
52+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_multiple_inputs_classification(server: RemoteOpenAIServer,
                                        model_name: str):
    """A batch of inputs yields one in-order labeled result per input."""
    input_texts = [
        "The product arrived on time and works perfectly",
        "I'm very satisfied with my purchase, would buy again",
        "The customer service was helpful and resolved my issue quickly",
        "This product broke after one week, terrible quality",
        "I'm very disappointed with this purchase, complete waste of money",
        "The customer service was rude and unhelpful",
    ]

    classification_response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": input_texts
        },
    )
    # Fail fast on HTTP errors (consistent with the other tests in this file)
    # instead of letting model_validate raise on an error payload.
    classification_response.raise_for_status()
    output = ClassificationResponse.model_validate(
        classification_response.json())

    assert len(output.data) == len(input_texts)
    for i, item in enumerate(output.data):
        # Results must come back in request order, fully populated.
        assert item.index == i
        assert hasattr(item, "label")
        assert hasattr(item, "probs")
        assert len(item.probs) == item.num_classes
        assert item.label in ["Default", "Spoiled"]
81+
82+
83+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str):
    """truncate_prompt_tokens caps usage at the requested token count."""
    response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": "hello " * 600,  # far longer than the 5-token budget
            "truncate_prompt_tokens": 5,
        },
    )
    response.raise_for_status()

    output = ClassificationResponse.model_validate(response.json())

    assert len(output.data) == 1
    assert output.data[0].index == 0
    assert hasattr(output.data[0], "probs")
    # Usage reflects the truncated prompt, not the original length.
    assert output.usage.prompt_tokens == 5
    assert output.usage.total_tokens == 5
105+
106+
107+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer,
                                              model_name: str):
    """Asking to truncate beyond --max-model-len (512) is a 400 error."""
    response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": "test",
            "truncate_prompt_tokens": 513,  # one past the server's limit
        },
    )

    body = response.json()
    assert response.status_code == 400
    assert body["object"] == "error"
    assert "truncate_prompt_tokens" in body["message"]
123+
124+
125+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
    """An empty string input is rejected with a 400 error payload."""
    response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": "",
        },
    )

    body = response.json()
    assert response.status_code == 400
    assert body["object"] == "error"
138+
139+
140+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_batch_classification_empty_list(server: RemoteOpenAIServer,
                                         model_name: str):
    """An empty input list succeeds and returns an empty data array."""
    response = requests.post(
        server.url_for("classify"),
        json={
            "model": model_name,
            "input": [],
        },
    )
    response.raise_for_status()

    output = ClassificationResponse.model_validate(response.json())

    assert output.object == "list"
    assert isinstance(output.data, list)
    assert not output.data

0 commit comments

Comments
 (0)