diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index 293ff1d98f1..6b1d3b3e905 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -205,6 +205,15 @@ public int MaximumConsecutiveErrorsPerRequest set => _maximumConsecutiveErrorsPerRequest = Throw.IfLessThan(value, 0); } + /// Gets or sets a delegate used to invoke instances. + /// + /// By default, the protected method is called for each to be invoked, + /// invoking the instance and returning its result. If this delegate is set to a non- value, + /// will replace its normal invocation with a call to this delegate, enabling + /// this delegate to assume all invocation handling of the function. + /// + public Func>? FunctionInvoker { get; set; } + /// public override async Task GetResponseAsync( IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) @@ -872,7 +881,9 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul { _ = Throw.IfNull(context); - return context.Function.InvokeAsync(context.Arguments, cancellationToken); + return FunctionInvoker is { } invoker ? + invoker(context, cancellationToken) : + context.Function.InvokeAsync(context.Arguments, cancellationToken); } private static TimeSpan GetElapsedTime(long startingTimestamp) => diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json index 4f4317c9978..59ed3d32fab 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json @@ -527,6 +527,10 @@ "Member": "System.IServiceProvider? Microsoft.Extensions.AI.FunctionInvokingChatClient.FunctionInvocationServices { get; }", "Stage": "Stable" }, + { + "Member": "System.Func>? Microsoft.Extensions.AI.FunctionInvokingChatClient.FunctionInvoker { get; set; }", + "Stage": "Stable" + }, { "Member": "bool Microsoft.Extensions.AI.FunctionInvokingChatClient.IncludeDetailedErrors { get; set; }", "Stage": "Stable" diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index 26554946dca..1379cef8bf0 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -37,6 +38,35 @@ public void Ctor_HasExpectedDefaults() Assert.False(client.IncludeDetailedErrors); Assert.Equal(10, client.MaximumIterationsPerRequest); Assert.Equal(3, client.MaximumConsecutiveErrorsPerRequest); + Assert.Null(client.FunctionInvoker); + } + + [Fact] + public void Properties_Roundtrip() + { + using TestChatClient innerClient = new(); + using FunctionInvokingChatClient client = new(innerClient); + + Assert.False(client.AllowConcurrentInvocation); + client.AllowConcurrentInvocation = true; + Assert.True(client.AllowConcurrentInvocation); + + Assert.False(client.IncludeDetailedErrors); + client.IncludeDetailedErrors = true; + Assert.True(client.IncludeDetailedErrors); + + Assert.Equal(10, client.MaximumIterationsPerRequest); + client.MaximumIterationsPerRequest = 5; + Assert.Equal(5, client.MaximumIterationsPerRequest); + + Assert.Equal(3, client.MaximumConsecutiveErrorsPerRequest); + client.MaximumConsecutiveErrorsPerRequest = 1; + Assert.Equal(1, client.MaximumConsecutiveErrorsPerRequest); + + Assert.Null(client.FunctionInvoker); + Func> invoker = (ctx, ct) => new ValueTask("test"); + client.FunctionInvoker = invoker; + Assert.Same(invoker, client.FunctionInvoker); } [Fact] @@ -208,6 +238,49 @@ public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync() await InvokeAndAssertStreamingAsync(options, plan); } + [Fact] + public async Task FunctionInvokerDelegateOverridesHandlingAsync() + { + var options = new ChatOptions + { + Tools = + [ + AIFunctionFactory.Create(() => "Result 1", "Func1"), + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + AIFunctionFactory.Create((int i) => { }, "VoidReturn"), + ] + }; + + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1 from delegate")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42 from delegate")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; + + Func configure = b => b.Use( + s => new FunctionInvokingChatClient(s) + { + FunctionInvoker = async (ctx, cancellationToken) => + { + Assert.NotNull(ctx); + var result = await ctx.Function.InvokeAsync(ctx.Arguments, cancellationToken); + return result is JsonElement e ? + JsonSerializer.SerializeToElement($"{e.GetString()} from delegate", AIJsonUtilities.DefaultOptions) : + result; + } + }); + + await InvokeAndAssertAsync(options, plan, configurePipeline: configure); + + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure); + } + [Fact] public async Task ContinuesWithSuccessfulCallsUntilMaximumIterations() {