From 43fd73806b80adb7922cf006e4bc84ebb2b2ed7b Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 17 Apr 2025 10:29:22 -0400 Subject: [PATCH] Utilize IServiceProviderIsService in AIFunctionFactory Add AIFunctionFactoryOptions.Services, and use it when examining function parameters to determine whether they should be resolved by default from DI. --- .../Functions/AIFunctionFactory.cs | 117 ++++++++++++++++-- .../Functions/AIFunctionFactoryOptions.cs | 10 ++ .../Functions/AIFunctionFactoryTest.cs | 60 ++++++++- 3 files changed, 171 insertions(+), 16 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index 29920409fbd..4878239f35b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -80,6 +80,16 @@ public static partial class AIFunctionFactory /// The handling of such parameters may be overridden via . /// /// + /// + /// + /// When the is constructed, it may be passed an via + /// . Any parameter that can be satisfied by that + /// according to will not be included in the generated JSON schema and will be resolved + /// from the provided to via , + /// rather than from the argument collection. The handling of such parameters may be overridden via + /// . + /// + /// /// /// All other parameter types are, by default, bound from the dictionary passed into /// and are included in the generated JSON schema. This may be overridden by the provided @@ -168,6 +178,15 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? optio /// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter. /// /// + /// + /// + /// When the is constructed, it may be passed an via + /// . Any parameter that can be satisfied by that + /// according to will not be included in the generated JSON schema and will be resolved + /// from the provided to via , + /// rather than from the argument collection. + /// + /// /// /// All other parameter types are bound from the dictionary passed into /// and are included in the generated JSON schema. @@ -260,6 +279,16 @@ public static AIFunction Create(Delegate method, string? name = null, string? de /// The handling of such parameters may be overridden via . /// /// + /// + /// + /// When the is constructed, it may be passed an via + /// . Any parameter that can be satisfied by that + /// according to will not be included in the generated JSON schema and will be resolved + /// from the provided to via , + /// rather than from the argument collection. The handling of such parameters may be overridden via + /// . + /// + /// /// /// All other parameter types are, by default, bound from the dictionary passed into /// and are included in the generated JSON schema. This may be overridden by the provided @@ -357,6 +386,15 @@ public static AIFunction Create(MethodInfo method, object? target, AIFunctionFac /// is allowed to be ; otherwise, /// must be non-, or else the invocation will fail with an exception due to the required nature of the parameter. /// + /// + /// + /// When the is constructed, it may be passed an via + /// . Any parameter that can be satisfied by that + /// according to will not be included in the generated JSON schema and will be resolved + /// from the provided to via , + /// rather than from the argument collection. + /// + /// /// /// /// All other parameter types are bound from the dictionary passed into @@ -465,6 +503,16 @@ public static AIFunction Create(MethodInfo method, object? target, string? name /// The handling of such parameters may be overridden via . /// /// + /// + /// + /// When the is constructed, it may be passed an via + /// . Any parameter that can be satisfied by that + /// according to will not be included in the generated JSON schema and will be resolved + /// from the provided to via , + /// rather than from the argument collection. The handling of such parameters may be overridden via + /// . + /// + /// /// /// All other parameter types are, by default, bound from the dictionary passed into /// and are included in the generated JSON schema. This may be overridden by the provided @@ -661,7 +709,7 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu serializerOptions.MakeReadOnly(); ConcurrentDictionary innerCache = _descriptorCache.GetOrCreateValue(serializerOptions); - DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, schemaOptions); + DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, options.Services, schemaOptions); if (innerCache.TryGetValue(key, out ReflectionAIFunctionDescriptor? descriptor)) { return descriptor; @@ -688,6 +736,8 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions } } + IServiceProviderIsService? serviceProviderIsService = key.Services?.GetService(); + // Use that binding information to impact the schema generation. AIJsonSchemaCreateOptions schemaOptions = key.SchemaOptions with { @@ -714,6 +764,14 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions return false; } + // We assume that if the services used to create the function support a particular type, + // so too do the services that will be passed into InvokeAsync. This is the same basic assumption + // made in ASP.NET. + if (serviceProviderIsService?.IsService(parameterInfo.ParameterType) is true) + { + return false; + } + // If there was an existing IncludeParameter delegate, now defer to it as we've // excluded everything we need to exclude. if (key.SchemaOptions.IncludeParameter is { } existingIncludeParameter) @@ -735,7 +793,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions options = default; } - ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i]); + ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i], serviceProviderIsService); } // Get a marshaling delegate for the return value. @@ -805,7 +863,8 @@ static bool IsAsyncMethod(MethodInfo method) private static Func GetParameterMarshaller( JsonSerializerOptions serializerOptions, AIFunctionFactoryOptions.ParameterBindingOptions bindingOptions, - ParameterInfo parameter) + ParameterInfo parameter, + IServiceProviderIsService? serviceProviderIsService) { if (string.IsNullOrWhiteSpace(parameter.Name)) { @@ -831,28 +890,28 @@ static bool IsAsyncMethod(MethodInfo method) // We're now into default handling of everything else. - // For AIFunctionArgument parameters, we bind to the arguments passed directly to InvokeAsync. + // For AIFunctionArgument parameters, we bind to the arguments passed to InvokeAsync. if (parameterType == typeof(AIFunctionArguments)) { return static (arguments, _) => arguments; } - // For IServiceProvider parameters, we bind to the services passed directly to InvokeAsync via AIFunctionArguments. + // For IServiceProvider parameters, we bind to the services passed to InvokeAsync via AIFunctionArguments. if (parameterType == typeof(IServiceProvider)) { return (arguments, _) => { IServiceProvider? services = arguments.Services; - if (services is null && !parameter.HasDefaultValue) + if (!parameter.HasDefaultValue && services is null) { - Throw.ArgumentException(nameof(arguments), $"An {nameof(IServiceProvider)} was not provided for the {parameter.Name} parameter."); + ThrowNullServices(parameter.Name); } return services; }; } - // For [FromKeyedServices] parameters, we bind to the services passed directly to InvokeAsync via AIFunctionArguments. + // For [FromKeyedServices] parameters, we resolve from the services passed to InvokeAsync via AIFunctionArguments. if (parameter.GetCustomAttribute(inherit: true) is { } keyedAttr) { return (arguments, _) => @@ -864,7 +923,38 @@ static bool IsAsyncMethod(MethodInfo method) if (!parameter.HasDefaultValue) { - Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' with key '{keyedAttr.Key}' was found."); + if (arguments.Services is null) + { + ThrowNullServices(parameter.Name); + } + + Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' with key '{keyedAttr.Key}' was found for parameter '{parameter.Name}'."); + } + + return parameter.DefaultValue; + }; + } + + // For any parameters that are satisfiable from the IServiceProvider, we resolve from the services passed to InvokeAsync + // via AIFunctionArguments. This is determined by the same same IServiceProviderIsService instance used to determine whether + // the parameter should be included in the schema. + if (serviceProviderIsService?.IsService(parameterType) is true) + { + return (arguments, _) => + { + if (arguments.Services?.GetService(parameterType) is { } service) + { + return service; + } + + if (!parameter.HasDefaultValue) + { + if (arguments.Services is null) + { + ThrowNullServices(parameter.Name); + } + + Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' was found for parameter '{parameter.Name}'."); } return parameter.DefaultValue; @@ -873,7 +963,7 @@ static bool IsAsyncMethod(MethodInfo method) // For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary. // Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found. - JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType); + JsonTypeInfo? typeInfo = serializerOptions.GetTypeInfo(parameterType); return (arguments, _) => { // If the parameter has an argument specified in the dictionary, return that argument. @@ -907,12 +997,16 @@ static bool IsAsyncMethod(MethodInfo method) // If the parameter is required and there's no argument specified for it, throw. if (!parameter.HasDefaultValue) { - Throw.ArgumentException(nameof(arguments), $"Missing required parameter '{parameter.Name}' for method '{parameter.Member.Name}'."); + Throw.ArgumentException(nameof(arguments), $"The arguments dictionary is missing a value for the required parameter '{parameter.Name}'."); } // Otherwise, use the optional parameter's default value. return parameter.DefaultValue; }; + + // Throws an ArgumentNullException indicating that AIFunctionArguments.Services must be provided. + static void ThrowNullServices(string parameterName) => + Throw.ArgumentNullException($"arguments.{nameof(AIFunctionArguments.Services)}", $"Services are required for parameter '{parameterName}'."); } /// @@ -1075,6 +1169,7 @@ private record struct DescriptorKey( string? Description, Func? GetBindParameterOptions, Func>? MarshalResult, + IServiceProvider? Services, AIJsonSchemaCreateOptions SchemaOptions); } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryOptions.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryOptions.cs index e71a4687422..80c8b485c59 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryOptions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryOptions.cs @@ -106,6 +106,16 @@ public AIFunctionFactoryOptions() /// public Func>? MarshalResult { get; set; } + /// + /// Gets or sets optional services used in the construction of the . + /// + /// + /// These services will be used to determine which parameters should be satisifed from dependency injection. As such, + /// what services are satisfied via this provider should match what's satisfied via the provider passed into + /// via . + /// + public IServiceProvider? Services { get; set; } + /// Provides configuration options produced by the delegate. public readonly record struct ParameterBindingOptions { diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs index b983189dac5..0670a06b206 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs @@ -14,6 +14,7 @@ #pragma warning disable IDE0004 // Remove Unnecessary Cast #pragma warning disable S107 // Methods should not have too many parameters +#pragma warning disable S2760 // Sequential tests should not check the same condition #pragma warning disable S3358 // Ternary operators should not be nested #pragma warning disable S5034 // "ValueTask" should be consumed correctly @@ -263,7 +264,7 @@ public async Task AIFunctionArguments_SatisfiesParameters() Assert.DoesNotContain("services", func.JsonSchema.ToString()); Assert.DoesNotContain("arguments", func.JsonSchema.ToString()); - await Assert.ThrowsAsync("arguments", () => func.InvokeAsync(arguments).AsTask()); + await Assert.ThrowsAsync("arguments.Services", () => func.InvokeAsync(arguments).AsTask()); arguments.Services = sp; var result = await func.InvokeAsync(arguments); @@ -298,6 +299,55 @@ public async Task AIFunctionArguments_MissingServicesMayBeOptional() Assert.Equal("", result?.ToString()); } + [Fact] + public async Task IServiceProvider_ServicesInOptionsImpactsFunctionCreation() + { + ServiceCollection sc = new(); + sc.AddSingleton(new MyService(123)); + IServiceProvider sp = sc.BuildServiceProvider(); + + AIFunction func; + + // Services not provided to Create, non-optional argument + if (JsonSerializer.IsReflectionEnabledByDefault) + { + func = AIFunctionFactory.Create((MyService myService) => myService.Value); + Assert.Contains("myService", func.JsonSchema.ToString()); + await Assert.ThrowsAsync("arguments", () => func.InvokeAsync(new()).AsTask()); + await Assert.ThrowsAsync("arguments", () => func.InvokeAsync(new() { Services = sp }).AsTask()); + } + else + { + Assert.Throws(() => AIFunctionFactory.Create((MyService myService) => myService.Value)); + } + + // Services not provided to Create, optional argument + if (JsonSerializer.IsReflectionEnabledByDefault) + { + func = AIFunctionFactory.Create((MyService? myService = null) => myService?.Value ?? 456); + Assert.Contains("myService", func.JsonSchema.ToString()); + Assert.Contains("456", (await func.InvokeAsync(new()))?.ToString()); + Assert.Contains("456", (await func.InvokeAsync(new() { Services = sp }))?.ToString()); + } + else + { + Assert.Throws(() => AIFunctionFactory.Create((MyService myService) => myService.Value)); + } + + // Services provided to Create, non-optional argument + func = AIFunctionFactory.Create((MyService myService) => myService.Value, new() { Services = sp }); + Assert.DoesNotContain("myService", func.JsonSchema.ToString()); + await Assert.ThrowsAsync("arguments.Services", () => func.InvokeAsync(new()).AsTask()); + await Assert.ThrowsAsync("arguments", () => func.InvokeAsync(new() { Services = new ServiceCollection().BuildServiceProvider() }).AsTask()); + Assert.Contains("123", (await func.InvokeAsync(new() { Services = sp }))?.ToString()); + + // Services provided to Create, optional argument + func = AIFunctionFactory.Create((MyService? myService = null) => myService?.Value ?? 456, new() { Services = sp }); + Assert.DoesNotContain("myService", func.JsonSchema.ToString()); + Assert.Contains("456", (await func.InvokeAsync(new()))?.ToString()); + Assert.Contains("123", (await func.InvokeAsync(new() { Services = sp }))?.ToString()); + } + [Fact] public async Task Create_NoInstance_UsesActivatorUtilitiesWhenServicesAvailable() { @@ -440,8 +490,8 @@ public async Task FromKeyedServices_ResolvesFromServiceProvider() Assert.Contains("myInteger", f.JsonSchema.ToString()); Assert.DoesNotContain("service", f.JsonSchema.ToString()); - Exception e = await Assert.ThrowsAsync(() => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask()); - Assert.Contains("No service of type", e.Message); + Exception e = await Assert.ThrowsAsync("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask()); + Assert.Contains("Services are required", e.Message); var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp }); Assert.Contains("43", result?.ToString()); @@ -461,8 +511,8 @@ public async Task FromKeyedServices_NullKeysBindToNonKeyedServices() Assert.Contains("myInteger", f.JsonSchema.ToString()); Assert.DoesNotContain("service", f.JsonSchema.ToString()); - Exception e = await Assert.ThrowsAsync(() => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask()); - Assert.Contains("No service of type", e.Message); + Exception e = await Assert.ThrowsAsync("arguments.Services", () => f.InvokeAsync(new() { ["myInteger"] = 1 }).AsTask()); + Assert.Contains("Services are required", e.Message); var result = await f.InvokeAsync(new() { ["myInteger"] = 1, Services = sp }); Assert.Contains("43", result?.ToString());