Skip to content

Commit 7a2ce7b

Browse files
authored
Translate OpenAI refusals to ErrorContent (#6393)
Refusals in OpenAI are errors reported when the service can't generate an output that matches the requested schema. Translate refusals to ErrorContent now that we have it.
1 parent 7b2620c commit 7a2ce7b

File tree

2 files changed

+47
-21
lines changed

2 files changed

+47
-21
lines changed

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -155,21 +155,22 @@ void IDisposable.Dispose()
155155

156156
foreach (var content in input.Contents)
157157
{
158-
if (content is FunctionCallContent callRequest)
158+
switch (content)
159159
{
160-
message.ToolCalls.Add(
161-
ChatToolCall.CreateFunctionToolCall(
162-
callRequest.CallId,
163-
callRequest.Name,
164-
new(JsonSerializer.SerializeToUtf8Bytes(
165-
callRequest.Arguments,
166-
options.GetTypeInfo(typeof(IDictionary<string, object?>))))));
167-
}
168-
}
160+
case ErrorContent errorContent when errorContent.ErrorCode is nameof(message.Refusal):
161+
message.Refusal = errorContent.Message;
162+
break;
169163

170-
if (input.AdditionalProperties?.TryGetValue(nameof(message.Refusal), out string? refusal) is true)
171-
{
172-
message.Refusal = refusal;
164+
case FunctionCallContent callRequest:
165+
message.ToolCalls.Add(
166+
ChatToolCall.CreateFunctionToolCall(
167+
callRequest.CallId,
168+
callRequest.Name,
169+
new(JsonSerializer.SerializeToUtf8Bytes(
170+
callRequest.Arguments,
171+
options.GetTypeInfo(typeof(IDictionary<string, object?>))))));
172+
break;
173+
}
173174
}
174175

175176
yield return message;
@@ -370,7 +371,7 @@ private static async IAsyncEnumerable<ChatResponseUpdate> FromOpenAIStreamingCha
370371
// add it to this function calling item.
371372
if (refusal is not null)
372373
{
373-
(responseUpdate.AdditionalProperties ??= [])[nameof(ChatMessageContentPart.Refusal)] = refusal.ToString();
374+
responseUpdate.Contents.Add(new ErrorContent(refusal.ToString()) { ErrorCode = "Refusal" });
374375
}
375376

376377
// Propagate additional relevant metadata.
@@ -450,6 +451,12 @@ private static ChatResponse FromOpenAIChatCompletion(ChatCompletion openAIComple
450451
}
451452
}
452453

454+
// And add error content for any refusals, which represent errors in generating output that conforms to a provided schema.
455+
if (openAICompletion.Refusal is string refusal)
456+
{
457+
returnMessage.Contents.Add(new ErrorContent(refusal) { ErrorCode = nameof(openAICompletion.Refusal) });
458+
}
459+
453460
// Wrap the content in a ChatResponse to return.
454461
var response = new ChatResponse(returnMessage)
455462
{
@@ -470,11 +477,6 @@ private static ChatResponse FromOpenAIChatCompletion(ChatCompletion openAIComple
470477
(response.AdditionalProperties ??= [])[nameof(openAICompletion.ContentTokenLogProbabilities)] = contentTokenLogProbs;
471478
}
472479

473-
if (openAICompletion.Refusal is string refusal)
474-
{
475-
(response.AdditionalProperties ??= [])[nameof(openAICompletion.Refusal)] = refusal;
476-
}
477-
478480
if (openAICompletion.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs)
479481
{
480482
(response.AdditionalProperties ??= [])[nameof(openAICompletion.RefusalTokenLogProbabilities)] = refusalTokenLogProbs;

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponseChatClient.cs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
263263
MessageId = lastMessageId,
264264
ModelId = modelId,
265265
ResponseId = responseId,
266+
Role = lastRole,
266267
ConversationId = responseId,
267268
Contents =
268269
[
@@ -274,6 +275,19 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
274275
],
275276
};
276277
break;
278+
279+
case StreamingResponseRefusalDoneUpdate refusalDone:
280+
yield return new ChatResponseUpdate
281+
{
282+
CreatedAt = createdAt,
283+
MessageId = lastMessageId,
284+
ModelId = modelId,
285+
ResponseId = responseId,
286+
Role = lastRole,
287+
ConversationId = responseId,
288+
Contents = [new ErrorContent(refusalDone.Refusal) { ErrorCode = nameof(ResponseContentPart.Refusal) }],
289+
};
290+
break;
277291
}
278292
}
279293
}
@@ -539,9 +553,15 @@ private static List<AIContent> ToAIContents(IEnumerable<ResponseContentPart> con
539553

540554
foreach (ResponseContentPart part in contents)
541555
{
542-
if (part.Kind == ResponseContentPartKind.OutputText)
556+
switch (part.Kind)
543557
{
544-
results.Add(new TextContent(part.Text));
558+
case ResponseContentPartKind.OutputText:
559+
results.Add(new TextContent(part.Text));
560+
break;
561+
562+
case ResponseContentPartKind.Refusal:
563+
results.Add(new ErrorContent(part.Refusal) { ErrorCode = nameof(ResponseContentPartKind.Refusal) });
564+
break;
545565
}
546566
}
547567

@@ -572,6 +592,10 @@ private static List<ResponseContentPart> ToOpenAIResponsesContent(IList<AIConten
572592
parts.Add(ResponseContentPart.CreateInputFilePart(null, $"{Guid.NewGuid():N}.pdf",
573593
BinaryData.FromBytes(JsonSerializer.SerializeToUtf8Bytes(dataContent.Uri, ResponseClientJsonContext.Default.String))));
574594
break;
595+
596+
case ErrorContent errorContent when errorContent.ErrorCode == nameof(ResponseContentPartKind.Refusal):
597+
parts.Add(ResponseContentPart.CreateRefusalPart(errorContent.Message));
598+
break;
575599
}
576600
}
577601

0 commit comments

Comments
 (0)