@@ -31,7 +31,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 	var id, textContentToReturn string
 	var created int

-	process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
+	process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
 		initialMessage := schema.OpenAIResponse{
 			ID:      id,
 			Created: created,
@@ -41,7 +41,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 		}
 		responses <- initialMessage

-		ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
+		_, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool {
 			usage := schema.OpenAIUsage{
 				PromptTokens:     tokenUsage.Prompt,
 				CompletionTokens: tokenUsage.Completion,
@@ -65,16 +65,19 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 			return true
 		})
 		close(responses)
+		return err
 	}
-	processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
+	processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
 		result := ""
-		_, tokenUsage, _ := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
+		_, tokenUsage, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
 			result += s
 			// TODO: Change generated BNF grammar to be compliant with the schema so we can
 			// stream the result token by token here.
 			return true
 		})
-
+		if err != nil {
+			return err
+		}
 		textContentToReturn = functions.ParseTextContent(result, config.FunctionsConfig)
 		result = functions.CleanupLLMResult(result, config.FunctionsConfig)
 		functionResults := functions.ParseFunctionCall(result, config.FunctionsConfig)
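
Note on the three hunks above: `process` and `processTools` were fire-and-forget closures, so any failure inside `ComputeChoices` disappeared once `close(responses)` ran and the client just saw a truncated stream. Both closures now return that error to the caller. A minimal, self-contained sketch of the pattern, with hypothetical stand-ins (`generate`, `Chunk`) rather than LocalAI's real API:

```go
package main

import (
	"errors"
	"fmt"
)

type Chunk struct{ Text string }

// generate stands in for ComputeChoices: it streams chunks and may fail.
func generate(out chan<- Chunk) error {
	out <- Chunk{Text: "hello"}
	return errors.New("backend died mid-stream")
}

// process mirrors the new closure shape: close(out) still runs on every
// path, but the terminal error is handed back instead of being dropped.
func process(out chan Chunk) error {
	err := generate(out)
	close(out)
	return err
}

func main() {
	out := make(chan Chunk)
	done := make(chan error, 1)
	go func() { done <- process(out) }()
	for c := range out {
		fmt.Println("chunk:", c.Text)
	}
	if err := <-done; err != nil {
		fmt.Println("stream failed:", err)
	}
}
```

Closing the channel on every path keeps the consumer from blocking forever; returning the error separately lets the HTTP layer decide what to tell the client.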
@@ -95,7 +98,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 		result, err := handleQuestion(config, cl, req, ml, startupOptions, functionResults, result, prompt)
 		if err != nil {
 			log.Error().Err(err).Msg("error handling question")
-			return
+			return err
 		}
 		usage := schema.OpenAIUsage{
 			PromptTokens:     tokenUsage.Prompt,
@@ -169,6 +172,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 		}

 		close(responses)
+		return err
 	}

 	return func(c *fiber.Ctx) error {
@@ -223,9 +227,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 			if err != nil {
 				return err
 			}
-			if d.Type == "json_object" {
+
+			switch d.Type {
+			case "json_object":
 				input.Grammar = functions.JSONBNF
-			} else if d.Type == "json_schema" {
+			case "json_schema":
 				d := schema.JsonSchemaRequest{}
 				dat, err := json.Marshal(config.ResponseFormatMap)
 				if err != nil {
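
The response-format dispatch also becomes a `switch`, which scales better now that a third shape (`json_schema`) exists alongside `json_object`. The `json_schema` branch re-encodes the untyped `config.ResponseFormatMap` with `json.Marshal` so it can be decoded into a typed `schema.JsonSchemaRequest`. A sketch of that marshal/unmarshal round-trip; the struct and its fields below are assumptions for illustration, not the real schema:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// JsonSchemaRequest is a hypothetical stand-in for the typed target.
type JsonSchemaRequest struct {
	Type       string          `json:"type"`
	JsonSchema json.RawMessage `json:"json_schema"`
}

func main() {
	// An untyped map, as it would arrive from decoding the request body.
	raw := map[string]interface{}{
		"type":        "json_schema",
		"json_schema": map[string]interface{}{"name": "answer"},
	}
	dat, err := json.Marshal(raw) // re-encode the map...
	if err != nil {
		panic(err)
	}
	var d JsonSchemaRequest
	if err := json.Unmarshal(dat, &d); err != nil { // ...then decode into the typed struct
		panic(err)
	}
	fmt.Println(d.Type, string(d.JsonSchema))
}
```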
@@ -326,31 +332,69 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 		c.Set("X-Correlation-ID", id)

 		responses := make(chan schema.OpenAIResponse)
+		ended := make(chan error, 1)

-		if !shouldUseFn {
-			go process(predInput, input, config, ml, responses, extraUsage)
-		} else {
-			go processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
-		}
+		go func() {
+			if !shouldUseFn {
+				ended <- process(predInput, input, config, ml, responses, extraUsage)
+			} else {
+				ended <- processTools(noActionName, predInput, input, config, ml, responses, extraUsage)
+			}
+		}()

 		c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
 			usage := &schema.OpenAIUsage{}
 			toolsCalled := false
-			for ev := range responses {
-				usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
-				if len(ev.Choices[0].Delta.ToolCalls) > 0 {
-					toolsCalled = true
-				}
-				var buf bytes.Buffer
-				enc := json.NewEncoder(&buf)
-				enc.Encode(ev)
-				log.Debug().Msgf("Sending chunk: %s", buf.String())
-				_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
-				if err != nil {
-					log.Debug().Msgf("Sending chunk failed: %v", err)
-					input.Cancel()
+
+		LOOP:
+			for {
+				select {
+				case ev := <-responses:
+					if len(ev.Choices) == 0 {
+						log.Debug().Msgf("No choices in the response, skipping")
+						continue
+					}
+					usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
+					if len(ev.Choices[0].Delta.ToolCalls) > 0 {
+						toolsCalled = true
+					}
+					var buf bytes.Buffer
+					enc := json.NewEncoder(&buf)
+					enc.Encode(ev)
+					log.Debug().Msgf("Sending chunk: %s", buf.String())
+					_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
+					if err != nil {
+						log.Debug().Msgf("Sending chunk failed: %v", err)
+						input.Cancel()
+					}
+					w.Flush()
+				case err := <-ended:
+					if err == nil {
+						break LOOP
+					}
+					log.Error().Msgf("Stream ended with error: %v", err)
+
+					resp := &schema.OpenAIResponse{
+						ID:      id,
+						Created: created,
+						Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+						Choices: []schema.Choice{
+							{
+								FinishReason: "stop",
+								Index:        0,
+								Delta:        &schema.Message{Content: "Internal error: " + err.Error()},
+							}},
+						Object: "chat.completion.chunk",
+						Usage:  *usage,
+					}
+					respData, _ := json.Marshal(resp)
+
+					w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
+					w.WriteString("data: [DONE]\n\n")
+					w.Flush()
+
+					return
 				}
-				w.Flush()
 			}

 			finishReason := "stop"
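
This hunk is the heart of the change. The workers now run inside a goroutine that forwards their return value into `ended` (buffered with capacity 1, presumably so the final send can complete even if the stream writer has already bailed out), and the writer's `for ev := range responses` loop becomes a labeled `select` that distinguishes data chunks from termination. Two fixes ride along: an empty `ev.Choices` no longer panics at `ev.Choices[0]`, and a worker error now reaches the client as a final `"Internal error: ..."` chunk followed by `data: [DONE]` instead of silently truncating the stream. A self-contained sketch of the pattern, with simplified stand-in types rather than the real schema:

```go
package main

import (
	"errors"
	"fmt"
)

// producer stands in for process/processTools: it streams chunks over
// responses and reports its exit status on ended (nil means clean EOF).
func producer(responses chan<- string, ended chan<- error) {
	responses <- "chunk-1"
	responses <- "chunk-2"
	ended <- errors.New("model crashed")
}

func main() {
	responses := make(chan string)
	ended := make(chan error, 1) // buffered: the producer never blocks on its final report
	go producer(responses, ended)

LOOP:
	for {
		select {
		case ev := <-responses:
			fmt.Printf("data: %s\n\n", ev) // SSE-style framing: payload line + blank line
		case err := <-ended:
			if err == nil {
				break LOOP // clean end of stream
			}
			// Surface the failure as one last event, then terminate the stream.
			fmt.Printf("data: {\"error\": %q}\n\n", err.Error())
			fmt.Print("data: [DONE]\n\n")
			return
		}
	}
	fmt.Print("data: [DONE]\n\n")
}
```

One subtlety: `select` picks randomly among ready cases, so in principle a pending chunk could be skipped when the error arrives. With an unbuffered `responses` channel, as here and in the patch, the producer cannot have queued unread chunks by the time its send on `ended` succeeds, so the ordering stays safe.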
@@ -378,7 +422,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 			w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
 			w.WriteString("data: [DONE]\n\n")
 			w.Flush()
+			log.Debug().Msgf("Stream ended")
 		}))

+
 		return nil

 		// no streaming mode
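
For context on the `data:` lines the writer emits: a server-sent event is a `data: <payload>` line followed by a blank line, and OpenAI-compatible streams end with the literal `data: [DONE]` sentinel. A minimal, hypothetical sketch of how a client consumes this framing, just to illustrate the wire format:

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// A canned stream standing in for the handler's output.
	stream := "data: {\"id\":\"1\"}\n\ndata: [DONE]\n\n"
	sc := bufio.NewScanner(strings.NewReader(stream))
	for sc.Scan() {
		line := sc.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue // skip the blank separator lines
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			fmt.Println("stream finished")
			return
		}
		fmt.Println("got event:", payload)
	}
}
```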