Skip to content

Commit 79a94c6

Browse files
stephentoubjeffhandley
authored andcommitted
Utilize IServiceProviderIsService in AIFunctionFactory (#6317)
Add AIFunctionFactoryOptions.Services, and use it when examining function parameters to determine whether they should be resolved by default from DI.
1 parent 99deb9e commit 79a94c6

File tree

3 files changed

+171
-16
lines changed

3 files changed

+171
-16
lines changed

src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs

Lines changed: 106 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,16 @@ public static partial class AIFunctionFactory
8080
/// The handling of such parameters may be overridden via <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/>.
8181
/// </description>
8282
/// </item>
83+
/// <item>
84+
/// <description>
85+
/// When the <see cref="AIFunction"/> is constructed, it may be passed an <see cref="IServiceProvider"/> via
86+
/// <see cref="AIFunctionFactoryOptions.Services"/>. Any parameter that can be satisfied by that <see cref="IServiceProvider"/>
87+
/// according to <see cref="IServiceProviderIsService"/> will not be included in the generated JSON schema and will be resolved
88+
/// from the <see cref="IServiceProvider"/> provided to <see cref="AIFunction.InvokeAsync"/> via <see cref="AIFunctionArguments.Services"/>,
89+
/// rather than from the argument collection. The handling of such parameters may be overridden via
90+
/// <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/>.
91+
/// </description>
92+
/// </item>
8393
/// </list>
8494
/// All other parameter types are, by default, bound from the <see cref="AIFunctionArguments"/> dictionary passed into <see cref="AIFunction.InvokeAsync"/>
8595
/// and are included in the generated JSON schema. This may be overridden by the <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/> provided
@@ -168,6 +178,15 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? optio
168178
/// must be non-<see langword="null"/>, or else the invocation will fail with an exception due to the required nature of the parameter.
169179
/// </description>
170180
/// </item>
181+
/// <item>
182+
/// <description>
183+
/// When the <see cref="AIFunction"/> is constructed, it may be passed an <see cref="IServiceProvider"/> via
184+
/// <see cref="AIFunctionFactoryOptions.Services"/>. Any parameter that can be satisfied by that <see cref="IServiceProvider"/>
185+
/// according to <see cref="IServiceProviderIsService"/> will not be included in the generated JSON schema and will be resolved
186+
/// from the <see cref="IServiceProvider"/> provided to <see cref="AIFunction.InvokeAsync"/> via <see cref="AIFunctionArguments.Services"/>,
187+
/// rather than from the argument collection.
188+
/// </description>
189+
/// </item>
171190
/// </list>
172191
/// All other parameter types are bound from the <see cref="AIFunctionArguments"/> dictionary passed into <see cref="AIFunction.InvokeAsync"/>
173192
/// and are included in the generated JSON schema.
@@ -260,6 +279,16 @@ public static AIFunction Create(Delegate method, string? name = null, string? de
260279
/// The handling of such parameters may be overridden via <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/>.
261280
/// </description>
262281
/// </item>
282+
/// <item>
283+
/// <description>
284+
/// When the <see cref="AIFunction"/> is constructed, it may be passed an <see cref="IServiceProvider"/> via
285+
/// <see cref="AIFunctionFactoryOptions.Services"/>. Any parameter that can be satisfied by that <see cref="IServiceProvider"/>
286+
/// according to <see cref="IServiceProviderIsService"/> will not be included in the generated JSON schema and will be resolved
287+
/// from the <see cref="IServiceProvider"/> provided to <see cref="AIFunction.InvokeAsync"/> via <see cref="AIFunctionArguments.Services"/>,
288+
/// rather than from the argument collection. The handling of such parameters may be overridden via
289+
/// <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/>.
290+
/// </description>
291+
/// </item>
263292
/// </list>
264293
/// All other parameter types are, by default, bound from the <see cref="AIFunctionArguments"/> dictionary passed into <see cref="AIFunction.InvokeAsync"/>
265294
/// and are included in the generated JSON schema. This may be overridden by the <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/> provided
@@ -357,6 +386,15 @@ public static AIFunction Create(MethodInfo method, object? target, AIFunctionFac
357386
/// <see cref="AIFunctionArguments.Services"/> is allowed to be <see langword="null"/>; otherwise, <see cref="AIFunctionArguments.Services"/>
358387
/// must be non-<see langword="null"/>, or else the invocation will fail with an exception due to the required nature of the parameter.
359388
/// </description>
389+
/// <item>
390+
/// <description>
391+
/// When the <see cref="AIFunction"/> is constructed, it may be passed an <see cref="IServiceProvider"/> via
392+
/// <see cref="AIFunctionFactoryOptions.Services"/>. Any parameter that can be satisfied by that <see cref="IServiceProvider"/>
393+
/// according to <see cref="IServiceProviderIsService"/> will not be included in the generated JSON schema and will be resolved
394+
/// from the <see cref="IServiceProvider"/> provided to <see cref="AIFunction.InvokeAsync"/> via <see cref="AIFunctionArguments.Services"/>,
395+
/// rather than from the argument collection.
396+
/// </description>
397+
/// </item>
360398
/// </item>
361399
/// </list>
362400
/// All other parameter types are bound from the <see cref="AIFunctionArguments"/> dictionary passed into <see cref="AIFunction.InvokeAsync"/>
@@ -465,6 +503,16 @@ public static AIFunction Create(MethodInfo method, object? target, string? name
465503
/// The handling of such parameters may be overridden via <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/>.
466504
/// </description>
467505
/// </item>
506+
/// <item>
507+
/// <description>
508+
/// When the <see cref="AIFunction"/> is constructed, it may be passed an <see cref="IServiceProvider"/> via
509+
/// <see cref="AIFunctionFactoryOptions.Services"/>. Any parameter that can be satisfied by that <see cref="IServiceProvider"/>
510+
/// according to <see cref="IServiceProviderIsService"/> will not be included in the generated JSON schema and will be resolved
511+
/// from the <see cref="IServiceProvider"/> provided to <see cref="AIFunction.InvokeAsync"/> via <see cref="AIFunctionArguments.Services"/>,
512+
/// rather than from the argument collection. The handling of such parameters may be overridden via
513+
/// <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/>.
514+
/// </description>
515+
/// </item>
468516
/// </list>
469517
/// All other parameter types are, by default, bound from the <see cref="AIFunctionArguments"/> dictionary passed into <see cref="AIFunction.InvokeAsync"/>
470518
/// and are included in the generated JSON schema. This may be overridden by the <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/> provided
@@ -661,7 +709,7 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu
661709
serializerOptions.MakeReadOnly();
662710
ConcurrentDictionary<DescriptorKey, ReflectionAIFunctionDescriptor> innerCache = _descriptorCache.GetOrCreateValue(serializerOptions);
663711

664-
DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, schemaOptions);
712+
DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, options.Services, schemaOptions);
665713
if (innerCache.TryGetValue(key, out ReflectionAIFunctionDescriptor? descriptor))
666714
{
667715
return descriptor;
@@ -688,6 +736,8 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
688736
}
689737
}
690738

739+
IServiceProviderIsService? serviceProviderIsService = key.Services?.GetService<IServiceProviderIsService>();
740+
691741
// Use that binding information to impact the schema generation.
692742
AIJsonSchemaCreateOptions schemaOptions = key.SchemaOptions with
693743
{
@@ -714,6 +764,14 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
714764
return false;
715765
}
716766

767+
// We assume that if the services used to create the function support a particular type,
768+
// so too do the services that will be passed into InvokeAsync. This is the same basic assumption
769+
// made in ASP.NET.
770+
if (serviceProviderIsService?.IsService(parameterInfo.ParameterType) is true)
771+
{
772+
return false;
773+
}
774+
717775
// If there was an existing IncludeParameter delegate, now defer to it as we've
718776
// excluded everything we need to exclude.
719777
if (key.SchemaOptions.IncludeParameter is { } existingIncludeParameter)
@@ -735,7 +793,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
735793
options = default;
736794
}
737795

738-
ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i]);
796+
ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i], serviceProviderIsService);
739797
}
740798

741799
// Get a marshaling delegate for the return value.
@@ -805,7 +863,8 @@ static bool IsAsyncMethod(MethodInfo method)
805863
private static Func<AIFunctionArguments, CancellationToken, object?> GetParameterMarshaller(
806864
JsonSerializerOptions serializerOptions,
807865
AIFunctionFactoryOptions.ParameterBindingOptions bindingOptions,
808-
ParameterInfo parameter)
866+
ParameterInfo parameter,
867+
IServiceProviderIsService? serviceProviderIsService)
809868
{
810869
if (string.IsNullOrWhiteSpace(parameter.Name))
811870
{
@@ -831,28 +890,28 @@ static bool IsAsyncMethod(MethodInfo method)
831890

832891
// We're now into default handling of everything else.
833892

834-
// For AIFunctionArgument parameters, we bind to the arguments passed directly to InvokeAsync.
893+
// For AIFunctionArgument parameters, we bind to the arguments passed to InvokeAsync.
835894
if (parameterType == typeof(AIFunctionArguments))
836895
{
837896
return static (arguments, _) => arguments;
838897
}
839898

840-
// For IServiceProvider parameters, we bind to the services passed directly to InvokeAsync via AIFunctionArguments.
899+
// For IServiceProvider parameters, we bind to the services passed to InvokeAsync via AIFunctionArguments.
841900
if (parameterType == typeof(IServiceProvider))
842901
{
843902
return (arguments, _) =>
844903
{
845904
IServiceProvider? services = arguments.Services;
846-
if (services is null && !parameter.HasDefaultValue)
905+
if (!parameter.HasDefaultValue && services is null)
847906
{
848-
Throw.ArgumentException(nameof(arguments), $"An {nameof(IServiceProvider)} was not provided for the {parameter.Name} parameter.");
907+
ThrowNullServices(parameter.Name);
849908
}
850909

851910
return services;
852911
};
853912
}
854913

855-
// For [FromKeyedServices] parameters, we bind to the services passed directly to InvokeAsync via AIFunctionArguments.
914+
// For [FromKeyedServices] parameters, we resolve from the services passed to InvokeAsync via AIFunctionArguments.
856915
if (parameter.GetCustomAttribute<FromKeyedServicesAttribute>(inherit: true) is { } keyedAttr)
857916
{
858917
return (arguments, _) =>
@@ -864,7 +923,38 @@ static bool IsAsyncMethod(MethodInfo method)
864923

865924
if (!parameter.HasDefaultValue)
866925
{
867-
Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' with key '{keyedAttr.Key}' was found.");
926+
if (arguments.Services is null)
927+
{
928+
ThrowNullServices(parameter.Name);
929+
}
930+
931+
Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' with key '{keyedAttr.Key}' was found for parameter '{parameter.Name}'.");
932+
}
933+
934+
return parameter.DefaultValue;
935+
};
936+
}
937+
938+
// For any parameters that are satisfiable from the IServiceProvider, we resolve from the services passed to InvokeAsync
939+
// via AIFunctionArguments. This is determined by the same same IServiceProviderIsService instance used to determine whether
940+
// the parameter should be included in the schema.
941+
if (serviceProviderIsService?.IsService(parameterType) is true)
942+
{
943+
return (arguments, _) =>
944+
{
945+
if (arguments.Services?.GetService(parameterType) is { } service)
946+
{
947+
return service;
948+
}
949+
950+
if (!parameter.HasDefaultValue)
951+
{
952+
if (arguments.Services is null)
953+
{
954+
ThrowNullServices(parameter.Name);
955+
}
956+
957+
Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' was found for parameter '{parameter.Name}'.");
868958
}
869959

870960
return parameter.DefaultValue;
@@ -873,7 +963,7 @@ static bool IsAsyncMethod(MethodInfo method)
873963

874964
// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
875965
// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
876-
JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);
966+
JsonTypeInfo? typeInfo = serializerOptions.GetTypeInfo(parameterType);
877967
return (arguments, _) =>
878968
{
879969
// If the parameter has an argument specified in the dictionary, return that argument.
@@ -907,12 +997,16 @@ static bool IsAsyncMethod(MethodInfo method)
907997
// If the parameter is required and there's no argument specified for it, throw.
908998
if (!parameter.HasDefaultValue)
909999
{
910-
Throw.ArgumentException(nameof(arguments), $"Missing required parameter '{parameter.Name}' for method '{parameter.Member.Name}'.");
1000+
Throw.ArgumentException(nameof(arguments), $"The arguments dictionary is missing a value for the required parameter '{parameter.Name}'.");
9111001
}
9121002

9131003
// Otherwise, use the optional parameter's default value.
9141004
return parameter.DefaultValue;
9151005
};
1006+
1007+
// Throws an ArgumentNullException indicating that AIFunctionArguments.Services must be provided.
1008+
static void ThrowNullServices(string parameterName) =>
1009+
Throw.ArgumentNullException($"arguments.{nameof(AIFunctionArguments.Services)}", $"Services are required for parameter '{parameterName}'.");
9161010
}
9171011

9181012
/// <summary>
@@ -1075,6 +1169,7 @@ private record struct DescriptorKey(
10751169
string? Description,
10761170
Func<ParameterInfo, AIFunctionFactoryOptions.ParameterBindingOptions>? GetBindParameterOptions,
10771171
Func<object?, Type?, CancellationToken, ValueTask<object?>>? MarshalResult,
1172+
IServiceProvider? Services,
10781173
AIJsonSchemaCreateOptions SchemaOptions);
10791174
}
10801175
}

src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryOptions.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,16 @@ public AIFunctionFactoryOptions()
106106
/// </remarks>
107107
public Func<object?, Type?, CancellationToken, ValueTask<object?>>? MarshalResult { get; set; }
108108

109+
/// <summary>
110+
/// Gets or sets optional services used in the construction of the <see cref="AIFunction"/>.
111+
/// </summary>
112+
/// <remarks>
113+
/// These services will be used to determine which parameters should be satisifed from dependency injection. As such,
114+
/// what services are satisfied via this provider should match what's satisfied via the provider passed into
115+
/// <see cref="AIFunction.InvokeAsync"/> via <see cref="AIFunctionArguments.Services"/>.
116+
/// </remarks>
117+
public IServiceProvider? Services { get; set; }
118+
109119
/// <summary>Provides configuration options produced by the <see cref="ConfigureParameterBinding"/> delegate.</summary>
110120
public readonly record struct ParameterBindingOptions
111121
{

0 commit comments

Comments
 (0)