Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 106 additions & 11 deletions src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ public static partial class AIFunctionFactory
/// The handling of such parameters may be overridden via <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/>.
/// </description>
/// </item>
/// <item>
/// <description>
/// When the <see cref="AIFunction"/> is constructed, it may be passed an <see cref="IServiceProvider"/> via
/// <see cref="AIFunctionFactoryOptions.Services"/>. Any parameter that can be satisfied by that <see cref="IServiceProvider"/>
/// according to <see cref="IServiceProviderIsService"/> will not be included in the generated JSON schema and will be resolved
/// from the <see cref="IServiceProvider"/> provided to <see cref="AIFunction.InvokeAsync"/> via <see cref="AIFunctionArguments.Services"/>,
/// rather than from the argument collection. The handling of such parameters may be overridden via
/// <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/>.
/// </description>
/// </item>
/// </list>
/// All other parameter types are, by default, bound from the <see cref="AIFunctionArguments"/> dictionary passed into <see cref="AIFunction.InvokeAsync"/>
/// and are included in the generated JSON schema. This may be overridden by the <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/> provided
Expand Down Expand Up @@ -168,6 +178,15 @@ public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? optio
/// must be non-<see langword="null"/>, or else the invocation will fail with an exception due to the required nature of the parameter.
/// </description>
/// </item>
/// <item>
/// <description>
/// When the <see cref="AIFunction"/> is constructed, it may be passed an <see cref="IServiceProvider"/> via
/// <see cref="AIFunctionFactoryOptions.Services"/>. Any parameter that can be satisfied by that <see cref="IServiceProvider"/>
/// according to <see cref="IServiceProviderIsService"/> will not be included in the generated JSON schema and will be resolved
/// from the <see cref="IServiceProvider"/> provided to <see cref="AIFunction.InvokeAsync"/> via <see cref="AIFunctionArguments.Services"/>,
/// rather than from the argument collection.
/// </description>
/// </item>
/// </list>
/// All other parameter types are bound from the <see cref="AIFunctionArguments"/> dictionary passed into <see cref="AIFunction.InvokeAsync"/>
/// and are included in the generated JSON schema.
Expand Down Expand Up @@ -260,6 +279,16 @@ public static AIFunction Create(Delegate method, string? name = null, string? de
/// The handling of such parameters may be overridden via <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/>.
/// </description>
/// </item>
/// <item>
/// <description>
/// When the <see cref="AIFunction"/> is constructed, it may be passed an <see cref="IServiceProvider"/> via
/// <see cref="AIFunctionFactoryOptions.Services"/>. Any parameter that can be satisfied by that <see cref="IServiceProvider"/>
/// according to <see cref="IServiceProviderIsService"/> will not be included in the generated JSON schema and will be resolved
/// from the <see cref="IServiceProvider"/> provided to <see cref="AIFunction.InvokeAsync"/> via <see cref="AIFunctionArguments.Services"/>,
/// rather than from the argument collection. The handling of such parameters may be overridden via
/// <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/>.
/// </description>
/// </item>
/// </list>
/// All other parameter types are, by default, bound from the <see cref="AIFunctionArguments"/> dictionary passed into <see cref="AIFunction.InvokeAsync"/>
/// and are included in the generated JSON schema. This may be overridden by the <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/> provided
Expand Down Expand Up @@ -357,6 +386,15 @@ public static AIFunction Create(MethodInfo method, object? target, AIFunctionFac
/// <see cref="AIFunctionArguments.Services"/> is allowed to be <see langword="null"/>; otherwise, <see cref="AIFunctionArguments.Services"/>
/// must be non-<see langword="null"/>, or else the invocation will fail with an exception due to the required nature of the parameter.
/// </description>
/// <item>
/// <description>
/// When the <see cref="AIFunction"/> is constructed, it may be passed an <see cref="IServiceProvider"/> via
/// <see cref="AIFunctionFactoryOptions.Services"/>. Any parameter that can be satisfied by that <see cref="IServiceProvider"/>
/// according to <see cref="IServiceProviderIsService"/> will not be included in the generated JSON schema and will be resolved
/// from the <see cref="IServiceProvider"/> provided to <see cref="AIFunction.InvokeAsync"/> via <see cref="AIFunctionArguments.Services"/>,
/// rather than from the argument collection.
/// </description>
/// </item>
/// </item>
/// </list>
/// All other parameter types are bound from the <see cref="AIFunctionArguments"/> dictionary passed into <see cref="AIFunction.InvokeAsync"/>
Expand Down Expand Up @@ -465,6 +503,16 @@ public static AIFunction Create(MethodInfo method, object? target, string? name
/// The handling of such parameters may be overridden via <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/>.
/// </description>
/// </item>
/// <item>
/// <description>
/// When the <see cref="AIFunction"/> is constructed, it may be passed an <see cref="IServiceProvider"/> via
/// <see cref="AIFunctionFactoryOptions.Services"/>. Any parameter that can be satisfied by that <see cref="IServiceProvider"/>
/// according to <see cref="IServiceProviderIsService"/> will not be included in the generated JSON schema and will be resolved
/// from the <see cref="IServiceProvider"/> provided to <see cref="AIFunction.InvokeAsync"/> via <see cref="AIFunctionArguments.Services"/>,
/// rather than from the argument collection. The handling of such parameters may be overridden via
/// <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/>.
/// </description>
/// </item>
/// </list>
/// All other parameter types are, by default, bound from the <see cref="AIFunctionArguments"/> dictionary passed into <see cref="AIFunction.InvokeAsync"/>
/// and are included in the generated JSON schema. This may be overridden by the <see cref="AIFunctionFactoryOptions.ConfigureParameterBinding"/> provided
Expand Down Expand Up @@ -661,7 +709,7 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu
serializerOptions.MakeReadOnly();
ConcurrentDictionary<DescriptorKey, ReflectionAIFunctionDescriptor> innerCache = _descriptorCache.GetOrCreateValue(serializerOptions);

DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, schemaOptions);
DescriptorKey key = new(method, options.Name, options.Description, options.ConfigureParameterBinding, options.MarshalResult, options.Services, schemaOptions);
if (innerCache.TryGetValue(key, out ReflectionAIFunctionDescriptor? descriptor))
{
return descriptor;
Expand All @@ -688,6 +736,8 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
}
}

IServiceProviderIsService? serviceProviderIsService = key.Services?.GetService<IServiceProviderIsService>();

// Use that binding information to impact the schema generation.
AIJsonSchemaCreateOptions schemaOptions = key.SchemaOptions with
{
Expand All @@ -714,6 +764,14 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
return false;
}

// We assume that if the services used to create the function support a particular type,
// so too do the services that will be passed into InvokeAsync. This is the same basic assumption
// made in ASP.NET.
if (serviceProviderIsService?.IsService(parameterInfo.ParameterType) is true)
{
return false;
}

// If there was an existing IncludeParameter delegate, now defer to it as we've
// excluded everything we need to exclude.
if (key.SchemaOptions.IncludeParameter is { } existingIncludeParameter)
Expand All @@ -735,7 +793,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
options = default;
}

ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i]);
ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, options, parameters[i], serviceProviderIsService);
}

// Get a marshaling delegate for the return value.
Expand Down Expand Up @@ -805,7 +863,8 @@ static bool IsAsyncMethod(MethodInfo method)
private static Func<AIFunctionArguments, CancellationToken, object?> GetParameterMarshaller(
JsonSerializerOptions serializerOptions,
AIFunctionFactoryOptions.ParameterBindingOptions bindingOptions,
ParameterInfo parameter)
ParameterInfo parameter,
IServiceProviderIsService? serviceProviderIsService)
{
if (string.IsNullOrWhiteSpace(parameter.Name))
{
Expand All @@ -831,28 +890,28 @@ static bool IsAsyncMethod(MethodInfo method)

// We're now into default handling of everything else.

// For AIFunctionArgument parameters, we bind to the arguments passed directly to InvokeAsync.
// For AIFunctionArgument parameters, we bind to the arguments passed to InvokeAsync.
if (parameterType == typeof(AIFunctionArguments))
{
return static (arguments, _) => arguments;
}

// For IServiceProvider parameters, we bind to the services passed directly to InvokeAsync via AIFunctionArguments.
// For IServiceProvider parameters, we bind to the services passed to InvokeAsync via AIFunctionArguments.
if (parameterType == typeof(IServiceProvider))
{
return (arguments, _) =>
{
IServiceProvider? services = arguments.Services;
if (services is null && !parameter.HasDefaultValue)
if (!parameter.HasDefaultValue && services is null)
{
Throw.ArgumentException(nameof(arguments), $"An {nameof(IServiceProvider)} was not provided for the {parameter.Name} parameter.");
ThrowNullServices(parameter.Name);
}

return services;
};
}

// For [FromKeyedServices] parameters, we bind to the services passed directly to InvokeAsync via AIFunctionArguments.
// For [FromKeyedServices] parameters, we resolve from the services passed to InvokeAsync via AIFunctionArguments.
if (parameter.GetCustomAttribute<FromKeyedServicesAttribute>(inherit: true) is { } keyedAttr)
{
return (arguments, _) =>
Expand All @@ -864,7 +923,38 @@ static bool IsAsyncMethod(MethodInfo method)

if (!parameter.HasDefaultValue)
{
Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' with key '{keyedAttr.Key}' was found.");
if (arguments.Services is null)
{
ThrowNullServices(parameter.Name);
}

Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' with key '{keyedAttr.Key}' was found for parameter '{parameter.Name}'.");
}

return parameter.DefaultValue;
};
}

// For any parameters that are satisfiable from the IServiceProvider, we resolve from the services passed to InvokeAsync
// via AIFunctionArguments. This is determined by the same same IServiceProviderIsService instance used to determine whether
// the parameter should be included in the schema.
if (serviceProviderIsService?.IsService(parameterType) is true)
{
return (arguments, _) =>
{
if (arguments.Services?.GetService(parameterType) is { } service)
{
return service;
}

if (!parameter.HasDefaultValue)
{
if (arguments.Services is null)
{
ThrowNullServices(parameter.Name);
}

Throw.ArgumentException(nameof(arguments), $"No service of type '{parameterType}' was found for parameter '{parameter.Name}'.");
}

return parameter.DefaultValue;
Expand All @@ -873,7 +963,7 @@ static bool IsAsyncMethod(MethodInfo method)

// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);
JsonTypeInfo? typeInfo = serializerOptions.GetTypeInfo(parameterType);
return (arguments, _) =>
{
// If the parameter has an argument specified in the dictionary, return that argument.
Expand Down Expand Up @@ -907,12 +997,16 @@ static bool IsAsyncMethod(MethodInfo method)
// If the parameter is required and there's no argument specified for it, throw.
if (!parameter.HasDefaultValue)
{
Throw.ArgumentException(nameof(arguments), $"Missing required parameter '{parameter.Name}' for method '{parameter.Member.Name}'.");
Throw.ArgumentException(nameof(arguments), $"The arguments dictionary is missing a value for the required parameter '{parameter.Name}'.");
}

// Otherwise, use the optional parameter's default value.
return parameter.DefaultValue;
};

// Throws an ArgumentNullException indicating that AIFunctionArguments.Services must be provided.
static void ThrowNullServices(string parameterName) =>
Throw.ArgumentNullException($"arguments.{nameof(AIFunctionArguments.Services)}", $"Services are required for parameter '{parameterName}'.");
}

/// <summary>
Expand Down Expand Up @@ -1075,6 +1169,7 @@ private record struct DescriptorKey(
string? Description,
Func<ParameterInfo, AIFunctionFactoryOptions.ParameterBindingOptions>? GetBindParameterOptions,
Func<object?, Type?, CancellationToken, ValueTask<object?>>? MarshalResult,
IServiceProvider? Services,
AIJsonSchemaCreateOptions SchemaOptions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ public AIFunctionFactoryOptions()
/// </remarks>
public Func<object?, Type?, CancellationToken, ValueTask<object?>>? MarshalResult { get; set; }

/// <summary>
/// Gets or sets optional services used in the construction of the <see cref="AIFunction"/>.
/// </summary>
/// <remarks>
/// These services will be used to determine which parameters should be satisifed from dependency injection. As such,
/// what services are satisfied via this provider should match what's satisfied via the provider passed into
/// <see cref="AIFunction.InvokeAsync"/> via <see cref="AIFunctionArguments.Services"/>.
/// </remarks>
public IServiceProvider? Services { get; set; }

/// <summary>Provides configuration options produced by the <see cref="ConfigureParameterBinding"/> delegate.</summary>
public readonly record struct ParameterBindingOptions
{
Expand Down
Loading
Loading