Commit ad99399

chore: stream errors while streaming SSE (#6160)
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent e6ebfd3 commit ad99399
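This commit makes the SSE streaming paths report backend errors to the client instead of silently closing the stream: process and processTools now return the error from ComputeChoices, the handlers forward it over a new ended channel, and the stream writers select over responses and ended. When generation fails, the chat endpoint emits one final chat.completion.chunk whose delta carries the error text, followed by the [DONE] sentinel. Illustratively (the JSON field values below are hypothetical, not taken from this commit), a failed stream now ends like this:

    data: {"object":"chat.completion.chunk","choices":[{"index":0,"finish_reason":"stop","delta":{"content":"Internal error: failed to load model"}}]}
    data: [DONE]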

2 files changed: +105 -39 lines changed

core/http/endpoints/openai/chat.go
Lines changed: 73 additions & 27 deletions

@@ -31,7 +31,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 	var id, textContentToReturn string
 	var created int
 
-	process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
+	process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
 		initialMessage := schema.OpenAIResponse{
 			ID:      id,
 			Created: created,
@@ -41,7 +41,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 		}
 		responses <- initialMessage
 
-		ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
+		_, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
 			usage := schema.OpenAIUsage{
 				PromptTokens:     tokenUsage.Prompt,
 				CompletionTokens: tokenUsage.Completion,
@@ -65,16 +65,19 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 			return true
 		})
 		close(responses)
+		return err
 	}
-	processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
+	processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
 		result := ""
-		_, tokenUsage, _ := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+		_, tokenUsage, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
 			result += s
 			// TODO: Change generated BNF grammar to be compliant with the schema so we can
 			// stream the result token by token here.
 			return true
 		})
-
+		if err != nil {
+			return err
+		}
 		textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
 		result = functions.CleanupLLMResult(result, config.FunctionsConfig)
 		functionResults := functions.ParseFunctionCall(result, config.FunctionsConfig)
@@ -95,7 +98,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 			result, err := handleQuestion(config, cl, req, ml, startupOptions, functionResults, result, prompt)
 			if err != nil {
 				log.Error().Err(err).Msg("error handling question")
-				return
+				return err
 			}
 			usage := schema.OpenAIUsage{
 				PromptTokens:     tokenUsage.Prompt,
@@ -169,6 +172,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 		}
 
 		close(responses)
+		return err
 	}
 
 	return func(c *fiber.Ctx) error {
@@ -223,9 +227,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 		if err != nil {
 			return err
 		}
-		if d.Type == "json_object" {
+
+		switch d.Type {
+		case "json_object":
 			input.Grammar = functions.JSONBNF
-		} else if d.Type == "json_schema" {
+		case "json_schema":
 			d := schema.JsonSchemaRequest{}
 			dat, err := json.Marshal(config.ResponseFormatMap)
 			if err != nil {
@@ -326,31 +332,69 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 		c.Set("X-Correlation-ID", id)
 
 		responses := make(chan schema.OpenAIResponse)
+		ended := make(chan error, 1)
 
-		if !shouldUseFn {
-			go process(predInput, input, config, ml, responses, extraUsage)
-		} else {
-			go processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
-		}
+		go func() {
+			if !shouldUseFn {
+				ended <- process(predInput, input, config, ml, responses, extraUsage)
+			} else {
+				ended <- processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
+			}
+		}()
 
 		c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
 			usage := &schema.OpenAIUsage{}
 			toolsCalled := false
-			for ev := range responses {
-				usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
-				if len(ev.Choices[0].Delta.ToolCalls) > 0 {
-					toolsCalled = true
-				}
-				var buf bytes.Buffer
-				enc := json.NewEncoder(&buf)
-				enc.Encode(ev)
-				log.Debug().Msgf("Sending chunk: %s", buf.String())
-				_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
-				if err != nil {
-					log.Debug().Msgf("Sending chunk failed: %v", err)
-					input.Cancel()
+
+		LOOP:
+			for {
+				select {
+				case ev := <-responses:
+					if len(ev.Choices) == 0 {
+						log.Debug().Msgf("No choices in the response, skipping")
+						continue
+					}
+					usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
+					if len(ev.Choices[0].Delta.ToolCalls) > 0 {
+						toolsCalled = true
+					}
+					var buf bytes.Buffer
+					enc := json.NewEncoder(&buf)
+					enc.Encode(ev)
+					log.Debug().Msgf("Sending chunk: %s", buf.String())
+					_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
+					if err != nil {
+						log.Debug().Msgf("Sending chunk failed: %v", err)
+						input.Cancel()
+					}
+					w.Flush()
+				case err := <-ended:
+					if err == nil {
+						break LOOP
+					}
+					log.Error().Msgf("Stream ended with error: %v", err)
+
+					resp := &schema.OpenAIResponse{
+						ID:      id,
+						Created: created,
+						Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+						Choices: []schema.Choice{
+							{
+								FinishReason: "stop",
+								Index:        0,
+								Delta:        &schema.Message{Content: "Internal error: " + err.Error()},
+							}},
+						Object: "chat.completion.chunk",
+						Usage:  *usage,
+					}
+					respData, _ := json.Marshal(resp)
+
+					w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
+					w.WriteString("data: [DONE]\n\n")
+					w.Flush()
+
+					return
 				}
-				w.Flush()
 			}
 
 			finishReason := "stop"
@@ -378,7 +422,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 			w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
 			w.WriteString("data: [DONE]\n\n")
 			w.Flush()
+			log.Debug().Msgf("Stream ended")
 		}))
+
 		return nil
 
 		// no streaming mode
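The heart of the change is the wiring above: a producer goroutine streams chunks over responses and reports its terminal error over ended, while the stream writer multiplexes the two with select. Here is a minimal, self-contained sketch of the same pattern, using a plain string channel in place of LocalAI's schema.OpenAIResponse (the channel names and the simulated failure are illustrative only):

package main

import (
	"errors"
	"fmt"
)

// produce stands in for process/processTools: it streams chunks over one
// channel, closes it when done, and returns its terminal error (nil on a
// clean run). The failure below is simulated for the sketch.
func produce(chunks chan string) error {
	defer close(chunks)
	for i := 0; i < 3; i++ {
		chunks <- fmt.Sprintf("chunk-%d", i)
	}
	return errors.New("backend failed mid-stream")
}

func main() {
	chunks := make(chan string)
	// Buffered by one slot, like `ended := make(chan error, 1)` in chat.go,
	// so the producer can deliver its result and exit even if the writer
	// has already returned.
	ended := make(chan error, 1)

	go func() { ended <- produce(chunks) }()

LOOP:
	for {
		select {
		case ev := <-chunks:
			// After close(chunks) this receive yields zero values; skipping
			// them mirrors the new `len(ev.Choices) == 0` guard above.
			if ev == "" {
				continue
			}
			fmt.Printf("data: %s\n\n", ev)
		case err := <-ended:
			if err == nil {
				break LOOP // clean end of stream
			}
			// Surface the failure to the client in-band, then terminate.
			fmt.Printf("data: Internal error: %v\n\n", err)
			break LOOP
		}
	}
	fmt.Print("data: [DONE]\n\n")
}

Two details carry over from the real code: buffering ended keeps the producer from blocking forever on its final send, and receives on a closed responses channel succeed immediately with a zero value, which is why the handlers now skip responses that have no choices.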

core/http/endpoints/openai/completion.go
Lines changed: 32 additions & 12 deletions

@@ -30,7 +30,7 @@ import (
 func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
 	created := int(time.Now().Unix())
 
-	process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
+	process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
 		tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
 			usage := schema.OpenAIUsage{
 				PromptTokens: tokenUsage.Prompt,
@@ -59,8 +59,9 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
 			responses <- resp
 			return true
 		}
-		ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
+		_, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback)
 		close(responses)
+		return err
 	}
 
 	return func(c *fiber.Ctx) error {
@@ -121,18 +122,37 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
 
 		responses := make(chan schema.OpenAIResponse)
 
-		go process(id, predInput, input, config, ml, responses, extraUsage)
+		ended := make(chan error)
+		go func() {
+			ended <- process(id, predInput, input, config, ml, responses, extraUsage)
+		}()
 
 		c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
 
-			for ev := range responses {
-				var buf bytes.Buffer
-				enc := json.NewEncoder(&buf)
-				enc.Encode(ev)
-
-				log.Debug().Msgf("Sending chunk: %s", buf.String())
-				fmt.Fprintf(w, "data: %v\n", buf.String())
-				w.Flush()
+		LOOP:
+			for {
+				select {
+				case ev := <-responses:
+					if len(ev.Choices) == 0 {
+						log.Debug().Msgf("No choices in the response, skipping")
+						continue
+					}
+					var buf bytes.Buffer
+					enc := json.NewEncoder(&buf)
+					enc.Encode(ev)
+
+					log.Debug().Msgf("Sending chunk: %s", buf.String())
+					fmt.Fprintf(w, "data: %v\n", buf.String())
+					w.Flush()
+				case err := <-ended:
+					if err == nil {
+						break LOOP
+					}
+					log.Error().Msgf("Stream ended with error: %v", err)
+					fmt.Fprintf(w, "data: %v\n", "Internal error: "+err.Error())
+					w.Flush()
+					break LOOP
+				}
 			}
 
 			resp := &schema.OpenAIResponse{
@@ -153,7 +173,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
 			w.WriteString("data: [DONE]\n\n")
 			w.Flush()
 		}))
-		return nil
+		return <-ended
 	}
 
 	var result []schema.Choice
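Unlike chat.go, this endpoint leaves ended unbuffered and finishes with return <-ended, so the handler itself propagates the terminal error. From the client side the failure now arrives in-band rather than as a dropped connection. A hypothetical Go consumer of the stream (the URL, port, and model name are assumptions about a local OpenAI-compatible instance, not part of this commit):

package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// Illustrative request against a locally running server.
	body := `{"model":"some-model","stream":true,"messages":[{"role":"user","content":"hi"}]}`
	resp, err := http.Post("http://localhost:8080/v1/chat/completions",
		"application/json", strings.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue // skip blank separator lines between SSE events
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			break // the stream ends with an explicit sentinel
		}
		// A mid-generation failure now arrives here as a regular chunk
		// whose delta content starts with "Internal error: ".
		fmt.Println("chunk:", payload)
	}
}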
