Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Shared.Diagnostics;

#pragma warning disable CA2213 // Disposable fields should be disposed

namespace Microsoft.Extensions.AI;

/// <summary>
Expand All @@ -34,18 +38,28 @@ namespace Microsoft.Extensions.AI;
/// invocation requests to that same function.
/// </para>
/// </remarks>
public class FunctionInvokingChatClient : DelegatingChatClient
public partial class FunctionInvokingChatClient : DelegatingChatClient
{
/// <summary>The logger to use for logging information about function invocation.</summary>
private readonly ILogger _logger;

/// <summary>The <see cref="ActivitySource"/> to use for telemetry.</summary>
/// <remarks>This component does not own the instance and should not dispose it.</remarks>
private readonly ActivitySource? _activitySource;

/// <summary>Maximum number of roundtrips allowed to the inner client.</summary>
private int? _maximumIterationsPerRequest;

/// <summary>
/// Initializes a new instance of the <see cref="FunctionInvokingChatClient"/> class.
/// </summary>
/// <param name="innerClient">The underlying <see cref="IChatClient"/>, or the next instance in a chain of clients.</param>
public FunctionInvokingChatClient(IChatClient innerClient)
/// <param name="logger">An <see cref="ILogger"/> to use for logging information about function invocation.</param>
public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null)
: base(innerClient)
{
_logger = logger ?? NullLogger.Instance;
_activitySource = innerClient.GetService<ActivitySource>();
}

/// <summary>
Expand Down Expand Up @@ -562,13 +576,95 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
/// </param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>The result of the function invocation. This may be null if the function invocation returned null.</returns>
protected virtual Task<object?> InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken)
protected virtual async Task<object?> InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken)
{
_ = Throw.IfNull(context);

return context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken);
using Activity? activity = _activitySource?.StartActivity(context.Function.Metadata.Name);

long startingTimestamp = 0;
if (_logger.IsEnabled(LogLevel.Debug))
{
startingTimestamp = Stopwatch.GetTimestamp();
if (_logger.IsEnabled(LogLevel.Trace))
{
LogInvokingSensitive(context.Function.Metadata.Name, LoggingHelpers.AsJson(context.CallContent.Arguments, context.Function.Metadata.JsonSerializerOptions));
}
else
{
LogInvoking(context.Function.Metadata.Name);
}
}

object? result = null;
try
{
result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
}
catch (Exception e)
{
if (activity is not null)
{
_ = activity.SetTag("error.type", e.GetType().FullName)
.SetStatus(ActivityStatusCode.Error, e.Message);
}

if (e is OperationCanceledException)
{
LogInvocationCanceled(context.Function.Metadata.Name);
}
else
{
LogInvocationFailed(context.Function.Metadata.Name, e);
}

throw;
}
finally
{
if (_logger.IsEnabled(LogLevel.Debug))
{
TimeSpan elapsed = GetElapsedTime(startingTimestamp);

if (result is not null && _logger.IsEnabled(LogLevel.Trace))
{
LogInvocationCompletedSensitive(context.Function.Metadata.Name, elapsed, LoggingHelpers.AsJson(result, context.Function.Metadata.JsonSerializerOptions));
}
else
{
LogInvocationCompleted(context.Function.Metadata.Name, elapsed);
}
}
}

return result;
}

private static TimeSpan GetElapsedTime(long startingTimestamp) =>
#if NET
Stopwatch.GetElapsedTime(startingTimestamp);
#else
new((long)((Stopwatch.GetTimestamp() - startingTimestamp) * ((double)TimeSpan.TicksPerSecond / Stopwatch.Frequency)));
#endif

[LoggerMessage(LogLevel.Debug, "Invoking {MethodName}.", SkipEnabledCheck = true)]
private partial void LogInvoking(string methodName);

[LoggerMessage(LogLevel.Trace, "Invoking {MethodName}({Arguments}).", SkipEnabledCheck = true)]
private partial void LogInvokingSensitive(string methodName, string arguments);

[LoggerMessage(LogLevel.Debug, "{MethodName} invocation completed. Duration: {Duration}", SkipEnabledCheck = true)]
private partial void LogInvocationCompleted(string methodName, TimeSpan duration);

[LoggerMessage(LogLevel.Trace, "{MethodName} invocation completed. Duration: {Duration}. Result: {Result}", SkipEnabledCheck = true)]
private partial void LogInvocationCompletedSensitive(string methodName, TimeSpan duration, string result);

[LoggerMessage(LogLevel.Debug, "{MethodName} invocation canceled.")]
private partial void LogInvocationCanceled(string methodName);

[LoggerMessage(LogLevel.Error, "{MethodName} invocation failed.")]
private partial void LogInvocationFailed(string methodName, Exception error);

/// <summary>Provides context for a function invocation.</summary>
public sealed class FunctionInvocationContext
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.AI;
Expand All @@ -16,15 +18,21 @@ public static class FunctionInvokingChatClientBuilderExtensions
/// </summary>
/// <remarks>This works by adding an instance of <see cref="FunctionInvokingChatClient"/> with default options.</remarks>
/// <param name="builder">The <see cref="ChatClientBuilder"/> being used to build the chat pipeline.</param>
/// <param name="loggerFactory">An optional <see cref="ILoggerFactory"/> to use to create a logger for logging function invocations.</param>
/// <param name="configure">An optional callback that can be used to configure the <see cref="FunctionInvokingChatClient"/> instance.</param>
/// <returns>The supplied <paramref name="builder"/>.</returns>
public static ChatClientBuilder UseFunctionInvocation(this ChatClientBuilder builder, Action<FunctionInvokingChatClient>? configure = null)
public static ChatClientBuilder UseFunctionInvocation(
this ChatClientBuilder builder,
ILoggerFactory? loggerFactory = null,
Action<FunctionInvokingChatClient>? configure = null)
{
_ = Throw.IfNull(builder);

return builder.Use(innerClient =>
return builder.Use((services, innerClient) =>
{
var chatClient = new FunctionInvokingChatClient(innerClient);
loggerFactory ??= services.GetService<ILoggerFactory>();

var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)));
configure?.Invoke(chatClient);
return chatClient;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt
}
}

private string AsJson<T>(T value) => JsonSerializer.Serialize(value, _jsonSerializerOptions.GetTypeInfo(typeof(T)));
private string AsJson<T>(T value) => LoggingHelpers.AsJson(value, _jsonSerializerOptions);

[LoggerMessage(LogLevel.Debug, "{MethodName} invoked.")]
private partial void LogInvoked(string methodName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Shared.Diagnostics;

#pragma warning disable S3358 // Ternary operators should not be nested

namespace Microsoft.Extensions.AI;

/// <summary>A delegating chat client that implements the OpenTelemetry Semantic Conventions for Generative AI systems.</summary>
Expand Down Expand Up @@ -106,6 +108,11 @@ protected override void Dispose(bool disposing)
/// </remarks>
public bool EnableSensitiveData { get; set; }

/// <inheritdoc/>
public override object? GetService(Type serviceType, object? serviceKey = null) =>
serviceType == typeof(ActivitySource) ? _activitySource :
base.GetService(serviceType, serviceKey);

/// <inheritdoc/>
public override async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -254,7 +261,7 @@ private static ChatCompletion ComposeStreamingUpdatesIntoChatCompletion(
string? modelId = options?.ModelId ?? _modelId;

activity = _activitySource.StartActivity(
$"{OpenTelemetryConsts.GenAI.Chat} {modelId}",
string.IsNullOrWhiteSpace(modelId) ? OpenTelemetryConsts.GenAI.Chat : $"{OpenTelemetryConsts.GenAI.Chat} {modelId}",
ActivityKind.Client);

if (activity is not null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,19 @@ public OpenTelemetryEmbeddingGenerator(IEmbeddingGenerator<TInput, TEmbedding> i
advice: new() { HistogramBucketBoundaries = OpenTelemetryConsts.GenAI.Client.OperationDuration.ExplicitBucketBoundaries });
}

/// <inheritdoc/>
public override object? GetService(Type serviceType, object? serviceKey = null) =>
serviceType == typeof(ActivitySource) ? _activitySource :
base.GetService(serviceType, serviceKey);

/// <inheritdoc/>
public override async Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(IEnumerable<TInput> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(values);

using Activity? activity = CreateAndConfigureActivity();
using Activity? activity = CreateAndConfigureActivity(options);
Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null;
string? requestModelId = options?.ModelId ?? _modelId;

GeneratedEmbeddings<TEmbedding>? response = null;
Exception? error = null;
Expand All @@ -93,7 +99,7 @@ public override async Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(IEnume
}
finally
{
TraceCompletion(activity, response, error, stopwatch);
TraceCompletion(activity, requestModelId, response, error, stopwatch);
}

return response;
Expand All @@ -112,18 +118,20 @@ protected override void Dispose(bool disposing)
}

/// <summary>Creates an activity for an embedding generation request, or returns null if not enabled.</summary>
private Activity? CreateAndConfigureActivity()
private Activity? CreateAndConfigureActivity(EmbeddingGenerationOptions? options)
{
Activity? activity = null;
if (_activitySource.HasListeners())
{
string? modelId = options?.ModelId ?? _modelId;

activity = _activitySource.StartActivity(
$"{OpenTelemetryConsts.GenAI.Embed} {_modelId}",
string.IsNullOrWhiteSpace(modelId) ? OpenTelemetryConsts.GenAI.Embed : $"{OpenTelemetryConsts.GenAI.Embed} {modelId}",
ActivityKind.Client,
default(ActivityContext),
[
new(OpenTelemetryConsts.GenAI.Operation.Name, OpenTelemetryConsts.GenAI.Embed),
new(OpenTelemetryConsts.GenAI.Request.Model, _modelId),
new(OpenTelemetryConsts.GenAI.Request.Model, modelId),
new(OpenTelemetryConsts.GenAI.SystemName, _modelProvider),
]);

Expand All @@ -149,6 +157,7 @@ protected override void Dispose(bool disposing)
/// <summary>Adds embedding generation response information to the activity.</summary>
private void TraceCompletion(
Activity? activity,
string? requestModelId,
GeneratedEmbeddings<TEmbedding>? embeddings,
Exception? error,
Stopwatch? stopwatch)
Expand All @@ -167,7 +176,7 @@ private void TraceCompletion(
if (_operationDurationHistogram.Enabled && stopwatch is not null)
{
TagList tags = default;
AddMetricTags(ref tags, responseModelId);
AddMetricTags(ref tags, requestModelId, responseModelId);
if (error is not null)
{
tags.Add(OpenTelemetryConsts.Error.Type, error.GetType().FullName);
Expand All @@ -180,7 +189,7 @@ private void TraceCompletion(
{
TagList tags = default;
tags.Add(OpenTelemetryConsts.GenAI.Token.Type, "input");
AddMetricTags(ref tags, responseModelId);
AddMetricTags(ref tags, requestModelId, responseModelId);

_tokenUsageHistogram.Record(inputTokens.Value);
}
Expand All @@ -206,13 +215,13 @@ private void TraceCompletion(
}
}

private void AddMetricTags(ref TagList tags, string? responseModelId)
private void AddMetricTags(ref TagList tags, string? requestModelId, string? responseModelId)
{
tags.Add(OpenTelemetryConsts.GenAI.Operation.Name, OpenTelemetryConsts.GenAI.Embed);

if (_modelId is string requestModel)
if (requestModelId is not null)
{
tags.Add(OpenTelemetryConsts.GenAI.Request.Model, requestModel);
tags.Add(OpenTelemetryConsts.GenAI.Request.Model, requestModelId);
}

tags.Add(OpenTelemetryConsts.GenAI.SystemName, _modelProvider);
Expand Down
34 changes: 34 additions & 0 deletions src/Libraries/Microsoft.Extensions.AI/LoggingHelpers.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#pragma warning disable CA1031 // Do not catch general exception types
#pragma warning disable S108 // Nested blocks of code should not be left empty
#pragma warning disable S2486 // Generic exceptions should not be ignored

using System.Text.Json;

namespace Microsoft.Extensions.AI;

/// <summary>Provides internal helpers for implementing logging.</summary>
internal static class LoggingHelpers
{
/// <summary>Serializes <paramref name="value"/> as JSON for logging purposes.</summary>
public static string AsJson<T>(T value, JsonSerializerOptions? options)
{
if (options?.TryGetTypeInfo(typeof(T), out var typeInfo) is true ||
AIJsonUtilities.DefaultOptions.TryGetTypeInfo(typeof(T), out typeInfo))
{
try
{
return JsonSerializer.Serialize(value, typeInfo);
}
catch
{
}
}

// If we're unable to get a type info for the value, or if we fail to serialize,
// return an empty JSON object. We do not want lack of type info to disrupt application behavior with exceptions.
return "{}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public sealed class TestChatClient : IChatClient

public Func<IList<ChatMessage>, ChatOptions?, CancellationToken, IAsyncEnumerable<StreamingChatCompletionUpdate>>? CompleteStreamingAsyncCallback { get; set; }

public Func<Type, object?, object?>? GetServiceCallback { get; set; }
public Func<Type, object?, object?> GetServiceCallback { get; set; } = (_, _) => null;

public Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
=> CompleteAsyncCallback!.Invoke(chatMessages, options, cancellationToken);
Expand All @@ -27,7 +27,7 @@ public IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(IL
=> CompleteStreamingAsyncCallback!.Invoke(chatMessages, options, cancellationToken);

public object? GetService(Type serviceType, object? serviceKey = null)
=> GetServiceCallback!(serviceType, serviceKey);
=> GetServiceCallback(serviceType, serviceKey);

void IDisposable.Dispose()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ public sealed class TestEmbeddingGenerator : IEmbeddingGenerator<string, Embeddi

public Func<IEnumerable<string>, EmbeddingGenerationOptions?, CancellationToken, Task<GeneratedEmbeddings<Embedding<float>>>>? GenerateAsyncCallback { get; set; }

public Func<Type, object?, object?>? GetServiceCallback { get; set; }
public Func<Type, object?, object?> GetServiceCallback { get; set; } = (_, _) => null;

public Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
=> GenerateAsyncCallback!.Invoke(values, options, cancellationToken);

public object? GetService(Type serviceType, object? serviceKey = null)
=> GetServiceCallback!(serviceType, serviceKey);
=> GetServiceCallback(serviceType, serviceKey);

void IDisposable.Dispose()
{
Expand Down
Loading
Loading