Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,15 @@ public int MaximumConsecutiveErrorsPerRequest
set => _maximumConsecutiveErrorsPerRequest = Throw.IfLessThan(value, 0);
}

/// <summary>Gets or sets a delegate used to invoke <see cref="AIFunction"/> instances.</summary>
/// <remarks>
/// By default, the protected <see cref="InvokeFunctionAsync"/> method is called for each <see cref="AIFunction"/> to be invoked,
/// invoking the instance and returning its result. If this delegate is set to a non-<see langword="null"/> value,
/// <see cref="InvokeFunctionAsync"/> will replace its normal invocation with a call to this delegate, enabling
/// this delegate to assume all invocation handling of the function.
/// </remarks>
public Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>>? FunctionInvoker { get; set; }

/// <inheritdoc/>
public override async Task<ChatResponse> GetResponseAsync(
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -872,7 +881,9 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
{
_ = Throw.IfNull(context);

return context.Function.InvokeAsync(context.Arguments, cancellationToken);
return FunctionInvoker is { } invoker ?
invoker(context, cancellationToken) :
context.Function.InvokeAsync(context.Arguments, cancellationToken);
}

private static TimeSpan GetElapsedTime(long startingTimestamp) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,10 @@
"Member": "System.IServiceProvider? Microsoft.Extensions.AI.FunctionInvokingChatClient.FunctionInvocationServices { get; }",
"Stage": "Stable"
},
{
"Member": "System.Func<Microsoft.Extensions.AI.FunctionInvocationContext, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask<object?>>? Microsoft.Extensions.AI.FunctionInvokingChatClient.FunctionInvoker { get; set; }",
"Stage": "Stable"
},
{
"Member": "bool Microsoft.Extensions.AI.FunctionInvokingChatClient.IncludeDetailedErrors { get; set; }",
"Stage": "Stable"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -37,6 +38,35 @@ public void Ctor_HasExpectedDefaults()
Assert.False(client.IncludeDetailedErrors);
Assert.Equal(10, client.MaximumIterationsPerRequest);
Assert.Equal(3, client.MaximumConsecutiveErrorsPerRequest);
Assert.Null(client.FunctionInvoker);
}

[Fact]
public void Properties_Roundtrip()
{
using TestChatClient innerClient = new();
using FunctionInvokingChatClient client = new(innerClient);

Assert.False(client.AllowConcurrentInvocation);
client.AllowConcurrentInvocation = true;
Assert.True(client.AllowConcurrentInvocation);

Assert.False(client.IncludeDetailedErrors);
client.IncludeDetailedErrors = true;
Assert.True(client.IncludeDetailedErrors);

Assert.Equal(10, client.MaximumIterationsPerRequest);
client.MaximumIterationsPerRequest = 5;
Assert.Equal(5, client.MaximumIterationsPerRequest);

Assert.Equal(3, client.MaximumConsecutiveErrorsPerRequest);
client.MaximumConsecutiveErrorsPerRequest = 1;
Assert.Equal(1, client.MaximumConsecutiveErrorsPerRequest);

Assert.Null(client.FunctionInvoker);
Func<FunctionInvocationContext, CancellationToken, ValueTask<object?>> invoker = (ctx, ct) => new ValueTask<object?>("test");
client.FunctionInvoker = invoker;
Assert.Same(invoker, client.FunctionInvoker);
}

[Fact]
Expand Down Expand Up @@ -208,6 +238,49 @@ public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync()
await InvokeAndAssertStreamingAsync(options, plan);
}

[Fact]
public async Task FunctionInvokerDelegateOverridesHandlingAsync()
{
var options = new ChatOptions
{
Tools =
[
AIFunctionFactory.Create(() => "Result 1", "Func1"),
AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
AIFunctionFactory.Create((int i) => { }, "VoidReturn"),
]
};

List<ChatMessage> plan =
[
new ChatMessage(ChatRole.User, "hello"),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1 from delegate")]),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42 from delegate")]),
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { { "i", 43 } })]),
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]),
new ChatMessage(ChatRole.Assistant, "world"),
];

Func<ChatClientBuilder, ChatClientBuilder> configure = b => b.Use(
s => new FunctionInvokingChatClient(s)
{
FunctionInvoker = async (ctx, cancellationToken) =>
{
Assert.NotNull(ctx);
var result = await ctx.Function.InvokeAsync(ctx.Arguments, cancellationToken);
return result is JsonElement e ?
JsonSerializer.SerializeToElement($"{e.GetString()} from delegate", AIJsonUtilities.DefaultOptions) :
result;
}
});

await InvokeAndAssertAsync(options, plan, configurePipeline: configure);

await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure);
}

[Fact]
public async Task ContinuesWithSuccessfulCallsUntilMaximumIterations()
{
Expand Down
Loading