Skip to content

Commit cb57095

Browse files
authored
Add FunctionInvokingChatClient.AdditionalTools (#6661)
* Add FunctionInvokingChatClient.AdditionalTools>
1 parent ff06191 commit cb57095

File tree

4 files changed

+123
-12
lines changed

4 files changed

+123
-12
lines changed

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,16 @@ public int MaximumConsecutiveErrorsPerRequest
205205
set => _maximumConsecutiveErrorsPerRequest = Throw.IfLessThan(value, 0);
206206
}
207207

208+
/// <summary>Gets or sets a collection of additional tools the client is able to invoke.</summary>
209+
/// <remarks>
210+
/// These will not impact the requests sent by the <see cref="FunctionInvokingChatClient"/>, which will pass through the
211+
/// <see cref="ChatOptions.Tools" /> unmodified. However, if the inner client requests the invocation of a tool
212+
/// that was not in <see cref="ChatOptions.Tools" />, this <see cref="AdditionalTools"/> collection will also be consulted
213+
/// to look for a corresponding tool to invoke. This is useful when the service may have been pre-configured to be aware
214+
/// of certain tools that aren't also sent on each individual request.
215+
/// </remarks>
216+
public IList<AITool>? AdditionalTools { get; set; }
217+
208218
/// <summary>Gets or sets a delegate used to invoke <see cref="AIFunction"/> instances.</summary>
209219
/// <remarks>
210220
/// By default, the protected <see cref="InvokeFunctionAsync"/> method is called for each <see cref="AIFunction"/> to be invoked,
@@ -250,7 +260,7 @@ public override async Task<ChatResponse> GetResponseAsync(
250260

251261
// Any function call work to do? If yes, ensure we're tracking that work in functionCallContents.
252262
bool requiresFunctionInvocation =
253-
options?.Tools is { Count: > 0 } &&
263+
(options?.Tools is { Count: > 0 } || AdditionalTools is { Count: > 0 }) &&
254264
iteration < MaximumIterationsPerRequest &&
255265
CopyFunctionCalls(response.Messages, ref functionCallContents);
256266

@@ -288,7 +298,7 @@ public override async Task<ChatResponse> GetResponseAsync(
288298

289299
// Add the responses from the function calls into the augmented history and also into the tracked
290300
// list of response messages.
291-
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken);
301+
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken);
292302
responseMessages.AddRange(modeAndMessages.MessagesAdded);
293303
consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;
294304

@@ -297,7 +307,7 @@ public override async Task<ChatResponse> GetResponseAsync(
297307
break;
298308
}
299309

300-
UpdateOptionsForNextIteration(ref options!, response.ConversationId);
310+
UpdateOptionsForNextIteration(ref options, response.ConversationId);
301311
}
302312

303313
Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages.");
@@ -367,7 +377,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
367377

368378
// If there are no tools to call, or for any other reason we should stop, return the response.
369379
if (functionCallContents is not { Count: > 0 } ||
370-
options?.Tools is not { Count: > 0 } ||
380+
(options?.Tools is not { Count: > 0 } && AdditionalTools is not { Count: > 0 }) ||
371381
iteration >= _maximumIterationsPerRequest)
372382
{
373383
break;
@@ -535,9 +545,16 @@ private static bool CopyFunctionCalls(
535545
return any;
536546
}
537547

538-
private static void UpdateOptionsForNextIteration(ref ChatOptions options, string? conversationId)
548+
private static void UpdateOptionsForNextIteration(ref ChatOptions? options, string? conversationId)
539549
{
540-
if (options.ToolMode is RequiredChatToolMode)
550+
if (options is null)
551+
{
552+
if (conversationId is not null)
553+
{
554+
options = new() { ConversationId = conversationId };
555+
}
556+
}
557+
else if (options.ToolMode is RequiredChatToolMode)
541558
{
542559
// We have to reset the tool mode to be non-required after the first iteration,
543560
// as otherwise we'll be in an infinite loop.
@@ -566,7 +583,7 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin
566583
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
567584
/// <returns>A value indicating how the caller should proceed.</returns>
568585
private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList<ChatMessage> MessagesAdded)> ProcessFunctionCallsAsync(
569-
List<ChatMessage> messages, ChatOptions options, List<FunctionCallContent> functionCallContents, int iteration, int consecutiveErrorCount,
586+
List<ChatMessage> messages, ChatOptions? options, List<FunctionCallContent> functionCallContents, int iteration, int consecutiveErrorCount,
570587
bool isStreaming, CancellationToken cancellationToken)
571588
{
572589
// We must add a response for every tool call, regardless of whether we successfully executed it or not.
@@ -695,13 +712,13 @@ private void ThrowIfNoFunctionResultsAdded(IList<ChatMessage>? messages)
695712
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
696713
/// <returns>A value indicating how the caller should proceed.</returns>
697714
private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
698-
List<ChatMessage> messages, ChatOptions options, List<FunctionCallContent> callContents,
715+
List<ChatMessage> messages, ChatOptions? options, List<FunctionCallContent> callContents,
699716
int iteration, int functionCallIndex, bool captureExceptions, bool isStreaming, CancellationToken cancellationToken)
700717
{
701718
var callContent = callContents[functionCallIndex];
702719

703720
// Look up the AIFunction for the function call. If the requested function isn't available, send back an error.
704-
AIFunction? aiFunction = options.Tools!.OfType<AIFunction>().FirstOrDefault(t => t.Name == callContent.Name);
721+
AIFunction? aiFunction = FindAIFunction(options?.Tools, callContent.Name) ?? FindAIFunction(AdditionalTools, callContent.Name);
705722
if (aiFunction is null)
706723
{
707724
return new(terminate: false, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null);
@@ -746,6 +763,23 @@ private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
746763
callContent,
747764
result,
748765
exception: null);
766+
767+
static AIFunction? FindAIFunction(IList<AITool>? tools, string functionName)
768+
{
769+
if (tools is not null)
770+
{
771+
int count = tools.Count;
772+
for (int i = 0; i < count; i++)
773+
{
774+
if (tools[i] is AIFunction function && function.Name == functionName)
775+
{
776+
return function;
777+
}
778+
}
779+
}
780+
781+
return null;
782+
}
749783
}
750784

751785
/// <summary>Creates one or more response messages for function invocation results.</summary>

src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,10 @@
515515
}
516516
],
517517
"Properties": [
518+
{
519+
"Member": "System.Collections.Generic.IList<Microsoft.Extensions.AI.AITool>? Microsoft.Extensions.AI.FunctionInvokingChatClient.AdditionalTools { get; set; }",
520+
"Stage": "Stable"
521+
},
518522
{
519523
"Member": "bool Microsoft.Extensions.AI.FunctionInvokingChatClient.AllowConcurrentInvocation { get; set; }",
520524
"Stage": "Stable"

test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ public void Ctor_HasExpectedDefaults()
3939
Assert.Equal(40, client.MaximumIterationsPerRequest);
4040
Assert.Equal(3, client.MaximumConsecutiveErrorsPerRequest);
4141
Assert.Null(client.FunctionInvoker);
42+
Assert.Null(client.AdditionalTools);
4243
}
4344

4445
[Fact]
@@ -67,6 +68,11 @@ public void Properties_Roundtrip()
6768
Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>> invoker = (ctx, ct) => new ValueTask<object?>("test");
6869
client.FunctionInvoker = invoker;
6970
Assert.Same(invoker, client.FunctionInvoker);
71+
72+
Assert.Null(client.AdditionalTools);
73+
IList<AITool> additionalTools = [AIFunctionFactory.Create(() => "Additional Tool")];
74+
client.AdditionalTools = additionalTools;
75+
Assert.Same(additionalTools, client.AdditionalTools);
7076
}
7177

7278
[Fact]
@@ -99,6 +105,73 @@ public async Task SupportsSingleFunctionCallPerRequestAsync()
99105
await InvokeAndAssertStreamingAsync(options, plan);
100106
}
101107

108+
[Theory]
109+
[InlineData(false)]
110+
[InlineData(true)]
111+
public async Task SupportsToolsProvidedByAdditionalTools(bool provideOptions)
112+
{
113+
ChatOptions? options = provideOptions ?
114+
new() { Tools = [AIFunctionFactory.Create(() => "Shouldn't be invoked", "ChatOptionsFunc")] } :
115+
null;
116+
117+
Func<ChatClientBuilder, ChatClientBuilder> configure = builder =>
118+
builder.UseFunctionInvocation(configure: c => c.AdditionalTools =
119+
[
120+
AIFunctionFactory.Create(() => "Result 1", "Func1"),
121+
AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
122+
AIFunctionFactory.Create((int i) => { }, "VoidReturn"),
123+
]);
124+
125+
List<ChatMessage> plan =
126+
[
127+
new ChatMessage(ChatRole.User, "hello"),
128+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
129+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
130+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
131+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]),
132+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { { "i", 43 } })]),
133+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]),
134+
new ChatMessage(ChatRole.Assistant, "world"),
135+
];
136+
137+
await InvokeAndAssertAsync(options, plan, configurePipeline: configure);
138+
139+
await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure);
140+
}
141+
142+
[Fact]
143+
public async Task PrefersToolsProvidedByChatOptions()
144+
{
145+
ChatOptions options = new()
146+
{
147+
Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")]
148+
};
149+
150+
Func<ChatClientBuilder, ChatClientBuilder> configure = builder =>
151+
builder.UseFunctionInvocation(configure: c => c.AdditionalTools =
152+
[
153+
AIFunctionFactory.Create(() => "Should never be invoked", "Func1"),
154+
AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
155+
AIFunctionFactory.Create((int i) => { }, "VoidReturn"),
156+
]);
157+
158+
List<ChatMessage> plan =
159+
[
160+
new ChatMessage(ChatRole.User, "hello"),
161+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
162+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
163+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
164+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]),
165+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { { "i", 43 } })]),
166+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]),
167+
new ChatMessage(ChatRole.Assistant, "world"),
168+
];
169+
170+
await InvokeAndAssertAsync(options, plan, configurePipeline: configure);
171+
172+
await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure);
173+
}
174+
102175
[Theory]
103176
[InlineData(false)]
104177
[InlineData(true)]
@@ -1002,7 +1075,7 @@ public override void Post(SendOrPostCallback d, object? state)
10021075
}
10031076

10041077
private static async Task<List<ChatMessage>> InvokeAndAssertAsync(
1005-
ChatOptions options,
1078+
ChatOptions? options,
10061079
List<ChatMessage> plan,
10071080
List<ChatMessage>? expected = null,
10081081
Func<ChatClientBuilder, ChatClientBuilder>? configurePipeline = null,
@@ -1102,7 +1175,7 @@ private static UsageDetails CreateRandomUsage()
11021175
}
11031176

11041177
private static async Task<List<ChatMessage>> InvokeAndAssertStreamingAsync(
1105-
ChatOptions options,
1178+
ChatOptions? options,
11061179
List<ChatMessage> plan,
11071180
List<ChatMessage>? expected = null,
11081181
Func<ChatClientBuilder, ChatClientBuilder>? configurePipeline = null,

test/Libraries/Microsoft.Extensions.AI.Tests/Microsoft.Extensions.AI.Tests.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
</PropertyGroup>
66

77
<PropertyGroup>
8-
<NoWarn>$(NoWarn);CA1063;CA1861;SA1130;VSTHRD003</NoWarn>
8+
<NoWarn>$(NoWarn);CA1063;CA1861;S104;SA1130;VSTHRD003</NoWarn>
99
<NoWarn>$(NoWarn);MEAI001</NoWarn>
1010
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
1111
</PropertyGroup>

0 commit comments

Comments
 (0)