diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs index a906d57c870..db256e94916 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs @@ -4,7 +4,9 @@ using System; using System.Collections.Generic; using System.Diagnostics; +#if !NET9_0_OR_GREATER using System.Runtime.CompilerServices; +#endif using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; @@ -100,8 +102,8 @@ async Task GetResponseViaSharedAsync( ChatResponse? response = null; await _sharedFunc(messages, options, async (messages, options, cancellationToken) => { - response = await InnerClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); - }, cancellationToken).ConfigureAwait(false); + response = await InnerClient.GetResponseAsync(messages, options, cancellationToken); + }, cancellationToken); if (response is null) { @@ -133,20 +135,19 @@ public override IAsyncEnumerable GetStreamingResponseAsync( { var updates = Channel.CreateBounded(1); -#pragma warning disable CA2016 // explicitly not forwarding the cancellation token, as we need to ensure the channel is always completed - _ = Task.Run(async () => -#pragma warning restore CA2016 + _ = ProcessAsync(); + async Task ProcessAsync() { Exception? error = null; try { await _sharedFunc(messages, options, async (messages, options, cancellationToken) => { - await foreach (var update in InnerClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var update in InnerClient.GetStreamingResponseAsync(messages, options, cancellationToken)) { - await updates.Writer.WriteAsync(update, cancellationToken).ConfigureAwait(false); + await updates.Writer.WriteAsync(update, cancellationToken); } - }, cancellationToken).ConfigureAwait(false); + }, cancellationToken); } catch (Exception ex) { @@ -157,7 +158,7 @@ await _sharedFunc(messages, options, async (messages, options, cancellationToken { _ = updates.Writer.TryComplete(error); } - }); + } #if NET9_0_OR_GREATER return updates.Reader.ReadAllAsync(cancellationToken); @@ -166,7 +167,7 @@ await _sharedFunc(messages, options, async (messages, options, cancellationToken static async IAsyncEnumerable ReadAllAsync( ChannelReader channel, [EnumeratorCancellation] CancellationToken cancellationToken) { - while (await channel.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + while (await channel.WaitToReadAsync(cancellationToken)) { while (channel.TryRead(out var update)) { @@ -187,7 +188,7 @@ static async IAsyncEnumerable ReadAllAsync( static async IAsyncEnumerable GetStreamingResponseAsyncViaGetResponseAsync(Task task) { - ChatResponse response = await task.ConfigureAwait(false); + ChatResponse response = await task; foreach (var update in response.ToChatResponseUpdates()) { yield return update; diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs index 61421b005e7..6fed2157b0b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs @@ -55,10 +55,10 @@ public override async Task GetResponseAsync( // concurrent callers might trigger duplicate requests, but that's acceptable. var cacheKey = GetCacheKey(messages, options, _boxedFalse); - if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result) + if (await ReadCacheAsync(cacheKey, cancellationToken) is not { } result) { - result = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); - await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false); + result = await base.GetResponseAsync(messages, options, cancellationToken); + await WriteCacheAsync(cacheKey, result, cancellationToken); } return result; @@ -77,7 +77,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // result and cache it. When we get a cache hit, we yield the non-streaming result as a streaming one. var cacheKey = GetCacheKey(messages, options, _boxedTrue); - if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatResponse) + if (await ReadCacheAsync(cacheKey, cancellationToken) is { } chatResponse) { // Yield all of the cached items. foreach (var chunk in chatResponse.ToChatResponseUpdates()) @@ -89,20 +89,20 @@ public override async IAsyncEnumerable GetStreamingResponseA { // Yield and store all of the items. List capturedItems = []; - await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken)) { capturedItems.Add(chunk); yield return chunk; } // Write the captured items to the cache as a non-streaming result. - await WriteCacheAsync(cacheKey, capturedItems.ToChatResponse(), cancellationToken).ConfigureAwait(false); + await WriteCacheAsync(cacheKey, capturedItems.ToChatResponse(), cancellationToken); } } else { var cacheKey = GetCacheKey(messages, options, _boxedTrue); - if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks) + if (await ReadCacheStreamingAsync(cacheKey, cancellationToken) is { } existingChunks) { // Yield all of the cached items. string? chatThreadId = null; @@ -116,14 +116,14 @@ public override async IAsyncEnumerable GetStreamingResponseA { // Yield and store all of the items. List capturedItems = []; - await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken)) { capturedItems.Add(chunk); yield return chunk; } // Write the captured items to the cache. - await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false); + await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken); } } } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs index b4e1e7f280f..a43bf5fac75 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs index 7ad8ea1d279..915b86b4ee3 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs @@ -221,7 +221,7 @@ public static async Task> GetResponseAsync( messages = [.. messages, promptAugmentation]; } - var result = await chatClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + var result = await chatClient.GetResponseAsync(messages, options, cancellationToken); return new ChatResponse(result, serializerOptions) { IsWrappedInObject = isWrappedInObject }; } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs index 5a5dfea06c3..50da3928157 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs @@ -36,13 +36,13 @@ public ConfigureOptionsChatClient(IChatClient innerClient, Action c /// public override async Task GetResponseAsync( IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => - await base.GetResponseAsync(messages, Configure(options), cancellationToken).ConfigureAwait(false); + await base.GetResponseAsync(messages, Configure(options), cancellationToken); /// public override async IAsyncEnumerable GetStreamingResponseAsync( IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - await foreach (var update in base.GetStreamingResponseAsync(messages, Configure(options), cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingResponseAsync(messages, Configure(options), cancellationToken)) { yield return update; } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs index 2312eadcb0d..158c560de14 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs @@ -52,7 +52,7 @@ public JsonSerializerOptions JsonSerializerOptions _ = Throw.IfNull(key); _jsonSerializerOptions.MakeReadOnly(); - if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson) + if (await _storage.GetAsync(key, cancellationToken) is byte[] existingJson) { return (ChatResponse?)JsonSerializer.Deserialize(existingJson, _jsonSerializerOptions.GetTypeInfo(typeof(ChatResponse))); } @@ -66,7 +66,7 @@ public JsonSerializerOptions JsonSerializerOptions _ = Throw.IfNull(key); _jsonSerializerOptions.MakeReadOnly(); - if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson) + if (await _storage.GetAsync(key, cancellationToken) is byte[] existingJson) { return (IReadOnlyList?)JsonSerializer.Deserialize(existingJson, _jsonSerializerOptions.GetTypeInfo(typeof(IReadOnlyList))); } @@ -82,7 +82,7 @@ protected override async Task WriteCacheAsync(string key, ChatResponse value, Ca _jsonSerializerOptions.MakeReadOnly(); var newJson = JsonSerializer.SerializeToUtf8Bytes(value, _jsonSerializerOptions.GetTypeInfo(typeof(ChatResponse))); - await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); + await _storage.SetAsync(key, newJson, cancellationToken); } /// @@ -93,7 +93,7 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList _jsonSerializerOptions.MakeReadOnly(); var newJson = JsonSerializer.SerializeToUtf8Bytes(value, _jsonSerializerOptions.GetTypeInfo(typeof(IReadOnlyList))); - await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); + await _storage.SetAsync(key, newJson, cancellationToken); } /// Computes a cache key for the specified values. diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index ad88ba90265..6978a01dd44 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -13,7 +13,6 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Shared.Diagnostics; -using static Microsoft.Extensions.AI.OpenTelemetryConsts.GenAI; #pragma warning disable CA2213 // Disposable fields should be disposed #pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test @@ -233,7 +232,7 @@ public override async Task GetResponseAsync( functionCallContents?.Clear(); // Make the call to the inner client. - response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + response = await base.GetResponseAsync(messages, options, cancellationToken); if (response is null) { Throw.InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}."); @@ -279,7 +278,7 @@ public override async Task GetResponseAsync( // Add the responses from the function calls into the augmented history and also into the tracked // list of response messages. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, cancellationToken); responseMessages.AddRange(modeAndMessages.MessagesAdded); consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; @@ -325,7 +324,7 @@ public override async IAsyncEnumerable GetStreamingResponseA updates.Clear(); functionCallContents?.Clear(); - await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken)) { if (update is null) { @@ -356,7 +355,7 @@ public override async IAsyncEnumerable GetStreamingResponseA FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); // Process all of the functions, adding their results into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, cancellationToken); responseMessages.AddRange(modeAndMessages.MessagesAdded); consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; @@ -534,7 +533,7 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin if (functionCallContents.Count == 1) { FunctionInvocationResult result = await ProcessFunctionCallAsync( - messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, cancellationToken).ConfigureAwait(false); + messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, cancellationToken); IList added = CreateResponseMessages([result]); ThrowIfNoFunctionResultsAdded(added); @@ -549,13 +548,15 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin if (AllowConcurrentInvocation) { - // Schedule the invocation of every function. - // In this case we always capture exceptions because the ordering is nondeterministic + // Rather than await'ing each function before invoking the next, invoke all of them + // and then await all of them. We avoid forcibly introducing parallelism via Task.Run, + // but if a function invocation completes asynchronously, its processing can overlap + // with the processing of other the other invocation invocations. results = await Task.WhenAll( from i in Enumerable.Range(0, functionCallContents.Count) - select Task.Run(() => ProcessFunctionCallAsync( + select ProcessFunctionCallAsync( messages, options, functionCallContents, - iteration, i, captureExceptions: true, cancellationToken))).ConfigureAwait(false); + iteration, i, captureExceptions: true, cancellationToken)); } else { @@ -565,7 +566,7 @@ select Task.Run(() => ProcessFunctionCallAsync( { results[i] = await ProcessFunctionCallAsync( messages, options, functionCallContents, - iteration, i, captureCurrentIterationExceptions, cancellationToken).ConfigureAwait(false); + iteration, i, captureCurrentIterationExceptions, cancellationToken); } } @@ -663,7 +664,7 @@ private async Task ProcessFunctionCallAsync( object? result; try { - result = await InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false); + result = await InvokeFunctionAsync(context, cancellationToken); } catch (Exception e) when (!cancellationToken.IsCancellationRequested) { @@ -763,7 +764,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul try { CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit - result = await context.Function.InvokeAsync(context.Arguments, cancellationToken).ConfigureAwait(false); + result = await context.Function.InvokeAsync(context.Arguments, cancellationToken); } catch (Exception e) { diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs index 51ca5a8f6d1..b5f43f5385b 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs @@ -60,7 +60,7 @@ public override async Task GetResponseAsync( try { - var response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + var response = await base.GetResponseAsync(messages, options, cancellationToken); if (_logger.IsEnabled(LogLevel.Debug)) { @@ -127,7 +127,7 @@ public override async IAsyncEnumerable GetStreamingResponseA { try { - if (!await e.MoveNextAsync().ConfigureAwait(false)) + if (!await e.MoveNextAsync()) { break; } @@ -164,7 +164,7 @@ public override async IAsyncEnumerable GetStreamingResponseA } finally { - await e.DisposeAsync().ConfigureAwait(false); + await e.DisposeAsync(); } } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs index df1717b4faa..c74bd3aa3c1 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs @@ -145,7 +145,7 @@ public override async Task GetResponseAsync( Exception? error = null; try { - response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + response = await base.GetResponseAsync(messages, options, cancellationToken); return response; } catch (Exception ex) @@ -183,7 +183,7 @@ public override async IAsyncEnumerable GetStreamingResponseA throw; } - var responseEnumerator = updates.ConfigureAwait(false).GetAsyncEnumerator(); + var responseEnumerator = updates.GetAsyncEnumerator(cancellationToken); List trackedUpdates = []; Exception? error = null; try diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/AnonymousDelegatingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/AnonymousDelegatingEmbeddingGenerator.cs index 0f6c696bd0d..a3a068b9c34 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/AnonymousDelegatingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/AnonymousDelegatingEmbeddingGenerator.cs @@ -39,6 +39,6 @@ public override async Task> GenerateAsync( { _ = Throw.IfNull(values); - return await _generateFunc(values, options, InnerGenerator, cancellationToken).ConfigureAwait(false); + return await _generateFunc(values, options, InnerGenerator, cancellationToken); } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs index 43a983d7fd4..2c880d7a22c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs @@ -42,19 +42,19 @@ public override async Task> GenerateAsync( // In the expected common case where we can cheaply tell there's only a single value and access it, // we can avoid all the overhead of splitting the list and reassembling it. var cacheKey = GetCacheKey(valuesList[0], options); - if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is TEmbedding e) + if (await ReadCacheAsync(cacheKey, cancellationToken) is TEmbedding e) { return [e]; } else { - var generated = await base.GenerateAsync(valuesList, options, cancellationToken).ConfigureAwait(false); + var generated = await base.GenerateAsync(valuesList, options, cancellationToken); if (generated.Count != 1) { Throw.InvalidOperationException($"Expected exactly one embedding to be generated, but received {generated.Count}."); } - await WriteCacheAsync(cacheKey, generated[0], cancellationToken).ConfigureAwait(false); + await WriteCacheAsync(cacheKey, generated[0], cancellationToken); return generated; } } @@ -72,7 +72,7 @@ public override async Task> GenerateAsync( // concurrent callers might trigger duplicate requests, but that's acceptable. var cacheKey = GetCacheKey(input, options); - if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is TEmbedding existing) + if (await ReadCacheAsync(cacheKey, cancellationToken) is TEmbedding existing) { results.Add(existing); } @@ -87,12 +87,12 @@ public override async Task> GenerateAsync( if (uncached is not null) { // Now make a single call to the wrapped generator to generate embeddings for all of the uncached inputs. - var uncachedResults = await base.GenerateAsync(uncached.Select(e => e.Input), options, cancellationToken).ConfigureAwait(false); + var uncachedResults = await base.GenerateAsync(uncached.Select(e => e.Input), options, cancellationToken); // Store the resulting embeddings into the cache individually. for (int i = 0; i < uncachedResults.Count; i++) { - await WriteCacheAsync(uncached[i].CacheKey, uncachedResults[i], cancellationToken).ConfigureAwait(false); + await WriteCacheAsync(uncached[i].CacheKey, uncachedResults[i], cancellationToken); } // Fill in the gaps with the newly generated results. diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs index 8332064f22a..7d7ef140af7 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/ConfigureOptionsEmbeddingGenerator.cs @@ -46,7 +46,7 @@ public override async Task> GenerateAsync( EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) { - return await base.GenerateAsync(values, Configure(options), cancellationToken).ConfigureAwait(false); + return await base.GenerateAsync(values, Configure(options), cancellationToken); } /// Creates and configures the to pass along to the inner client. diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs index d6c20ffb2f5..cd26879d040 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs @@ -57,7 +57,7 @@ public JsonSerializerOptions JsonSerializerOptions _ = Throw.IfNull(key); _jsonSerializerOptions.MakeReadOnly(); - if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson) + if (await _storage.GetAsync(key, cancellationToken) is byte[] existingJson) { return JsonSerializer.Deserialize(existingJson, (JsonTypeInfo)_jsonSerializerOptions.GetTypeInfo(typeof(TEmbedding))); } @@ -73,7 +73,7 @@ protected override async Task WriteCacheAsync(string key, TEmbedding value, Canc _jsonSerializerOptions.MakeReadOnly(); var newJson = JsonSerializer.SerializeToUtf8Bytes(value, (JsonTypeInfo)_jsonSerializerOptions.GetTypeInfo(typeof(TEmbedding))); - await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); + await _storage.SetAsync(key, newJson, cancellationToken); } /// Computes a cache key for the specified values. diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs index 84d4815cb23..751a5edd443 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderEmbeddingGeneratorExtensions.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs index 90553ca5411..924ee362633 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -62,7 +61,7 @@ public override async Task> GenerateAsync(IEnume try { - var embeddings = await base.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); + var embeddings = await base.GenerateAsync(values, options, cancellationToken); LogCompleted(embeddings.Count); diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs index f6983408b85..14332d1253f 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs @@ -104,7 +104,7 @@ public override async Task> GenerateAsync(IEnume Exception? error = null; try { - response = await base.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); + response = await base.GenerateAsync(values, options, cancellationToken); } catch (Exception ex) { diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index 6537f3aa3ab..41550ba0451 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -303,7 +303,7 @@ private ReflectionAIFunction( } return await FunctionDescriptor.ReturnParameterMarshaller( - ReflectionInvoke(FunctionDescriptor.Method, target, args), cancellationToken).ConfigureAwait(false); + ReflectionInvoke(FunctionDescriptor.Method, target, args), cancellationToken); } finally { @@ -311,7 +311,7 @@ private ReflectionAIFunction( { if (target is IAsyncDisposable ad) { - await ad.DisposeAsync().ConfigureAwait(false); + await ad.DisposeAsync(); } else if (target is IDisposable d) { @@ -599,14 +599,14 @@ static bool IsAsyncMethod(MethodInfo method) { return async (result, cancellationToken) => { - await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); - return await marshalResult(null, null, cancellationToken).ConfigureAwait(false); + await ((Task)ThrowIfNullResult(result)); + return await marshalResult(null, null, cancellationToken); }; } return async static (result, _) => { - await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); + await ((Task)ThrowIfNullResult(result)); return null; }; } @@ -618,14 +618,14 @@ static bool IsAsyncMethod(MethodInfo method) { return async (result, cancellationToken) => { - await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false); - return await marshalResult(null, null, cancellationToken).ConfigureAwait(false); + await ((ValueTask)ThrowIfNullResult(result)); + return await marshalResult(null, null, cancellationToken); }; } return async static (result, _) => { - await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false); + await ((ValueTask)ThrowIfNullResult(result)); return null; }; } @@ -640,18 +640,18 @@ static bool IsAsyncMethod(MethodInfo method) { return async (taskObj, cancellationToken) => { - await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false); + await ((Task)ThrowIfNullResult(taskObj)); object? result = ReflectionInvoke(taskResultGetter, taskObj, null); - return await marshalResult(result, taskResultGetter.ReturnType, cancellationToken).ConfigureAwait(false); + return await marshalResult(result, taskResultGetter.ReturnType, cancellationToken); }; } returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType); return async (taskObj, cancellationToken) => { - await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false); + await ((Task)ThrowIfNullResult(taskObj)); object? result = ReflectionInvoke(taskResultGetter, taskObj, null); - return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + return await SerializeResultAsync(result, returnTypeInfo, cancellationToken); }; } @@ -666,9 +666,9 @@ static bool IsAsyncMethod(MethodInfo method) return async (taskObj, cancellationToken) => { var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; - await task.ConfigureAwait(false); + await task; object? result = ReflectionInvoke(asTaskResultGetter, task, null); - return await marshalResult(result, asTaskResultGetter.ReturnType, cancellationToken).ConfigureAwait(false); + return await marshalResult(result, asTaskResultGetter.ReturnType, cancellationToken); }; } @@ -676,9 +676,9 @@ static bool IsAsyncMethod(MethodInfo method) return async (taskObj, cancellationToken) => { var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; - await task.ConfigureAwait(false); + await task; object? result = ReflectionInvoke(asTaskResultGetter, task, null); - return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + return await SerializeResultAsync(result, returnTypeInfo, cancellationToken); }; } } @@ -702,7 +702,7 @@ static bool IsAsyncMethod(MethodInfo method) // Serialize asynchronously to support potential IAsyncEnumerable responses. using PooledMemoryStream stream = new(); - await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(false); + await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken); Utf8JsonReader reader = new(stream.GetBuffer()); return JsonElement.ParseValue(ref reader); } diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj index c851ccfb846..3b621827213 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj @@ -16,6 +16,16 @@ $(TargetFrameworks);netstandard2.0 $(NoWarn);CA2227;CA1034;SA1316;S1067;S1121;S1994;S3253 + + + $(NoWarn);CA2007 + true true diff --git a/src/Libraries/Microsoft.Extensions.AI/SpeechToText/ConfigureOptionsSpeechToTextClient.cs b/src/Libraries/Microsoft.Extensions.AI/SpeechToText/ConfigureOptionsSpeechToTextClient.cs index 85833a3c171..1601b3c5073 100644 --- a/src/Libraries/Microsoft.Extensions.AI/SpeechToText/ConfigureOptionsSpeechToTextClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/SpeechToText/ConfigureOptionsSpeechToTextClient.cs @@ -40,14 +40,14 @@ public ConfigureOptionsSpeechToTextClient(ISpeechToTextClient innerClient, Actio public override async Task GetTextAsync( Stream audioSpeechStream, SpeechToTextOptions? options = null, CancellationToken cancellationToken = default) { - return await base.GetTextAsync(audioSpeechStream, Configure(options), cancellationToken).ConfigureAwait(false); + return await base.GetTextAsync(audioSpeechStream, Configure(options), cancellationToken); } /// public override async IAsyncEnumerable GetStreamingTextAsync( Stream audioSpeechStream, SpeechToTextOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - await foreach (var update in base.GetStreamingTextAsync(audioSpeechStream, Configure(options), cancellationToken).ConfigureAwait(false)) + await foreach (var update in base.GetStreamingTextAsync(audioSpeechStream, Configure(options), cancellationToken)) { yield return update; } diff --git a/src/Libraries/Microsoft.Extensions.AI/SpeechToText/LoggingSpeechToTextClient.cs b/src/Libraries/Microsoft.Extensions.AI/SpeechToText/LoggingSpeechToTextClient.cs index 4494d319dc0..6c5bf0ed929 100644 --- a/src/Libraries/Microsoft.Extensions.AI/SpeechToText/LoggingSpeechToTextClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/SpeechToText/LoggingSpeechToTextClient.cs @@ -63,7 +63,7 @@ public override async Task GetTextAsync( try { - var response = await base.GetTextAsync(audioSpeechStream, options, cancellationToken).ConfigureAwait(false); + var response = await base.GetTextAsync(audioSpeechStream, options, cancellationToken); if (_logger.IsEnabled(LogLevel.Debug)) { @@ -130,7 +130,7 @@ public override async IAsyncEnumerable GetStreamingT { try { - if (!await e.MoveNextAsync().ConfigureAwait(false)) + if (!await e.MoveNextAsync()) { break; } @@ -167,7 +167,7 @@ public override async IAsyncEnumerable GetStreamingT } finally { - await e.DisposeAsync().ConfigureAwait(false); + await e.DisposeAsync(); } } diff --git a/src/Libraries/Microsoft.Extensions.AI/SpeechToText/SpeechToTextClientBuilderSpeechToTextClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/SpeechToText/SpeechToTextClientBuilderSpeechToTextClientExtensions.cs index 29569c55207..650282949f8 100644 --- a/src/Libraries/Microsoft.Extensions.AI/SpeechToText/SpeechToTextClientBuilderSpeechToTextClientExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/SpeechToText/SpeechToTextClientBuilderSpeechToTextClientExtensions.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics.CodeAnalysis; -using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index 30332cb3e3c..67b2025b7de 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -122,15 +122,22 @@ public async Task SupportsMultipleFunctionCallsPerRequestAsync(bool concurrentIn [Fact] public async Task ParallelFunctionCallsMayBeInvokedConcurrentlyAsync() { - using var barrier = new Barrier(2); + int remaining = 2; + var tcs = new TaskCompletionSource(); var options = new ChatOptions { Tools = [ - AIFunctionFactory.Create((string arg) => + AIFunctionFactory.Create(async (string arg) => { - barrier.SignalAndWait(); + if (Interlocked.Decrement(ref remaining) == 0) + { + tcs.SetResult(true); + } + + await tcs.Task; + return arg + arg; }, "Func"), ] @@ -867,6 +874,62 @@ public async Task FunctionInvocations_PassesServices() await InvokeAndAssertAsync(options, plan, services: expected); } + [Fact] + public async Task FunctionInvocations_InvokedOnOriginalSynchronizationContext() + { + SynchronizationContext ctx = new CustomSynchronizationContext(); + SynchronizationContext.SetSynchronizationContext(ctx); + + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [ + new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg"] = "value1" }), + new FunctionCallContent("callId2", "Func1", new Dictionary { ["arg"] = "value2" }), + ]), + new ChatMessage(ChatRole.Tool, + [ + new FunctionResultContent("callId2", result: "value1"), + new FunctionResultContent("callId2", result: "value2") + ]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; + + var options = new ChatOptions + { + Tools = [AIFunctionFactory.Create(async (string arg, CancellationToken cancellationToken) => + { + await Task.Delay(1, cancellationToken); + Assert.Same(ctx, SynchronizationContext.Current); + return arg; + }, "Func1")] + }; + + Func configurePipeline = builder => builder + .Use(async (messages, options, next, cancellationToken) => + { + await Task.Delay(1, cancellationToken); + await next(messages, options, cancellationToken); + }) + .UseOpenTelemetry() + .UseFunctionInvocation(configure: c => { c.AllowConcurrentInvocation = true; c.IncludeDetailedErrors = true; }); + + await InvokeAndAssertAsync(options, plan, configurePipeline: configurePipeline); + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configurePipeline); + } + + private sealed class CustomSynchronizationContext : SynchronizationContext + { + public override void Post(SendOrPostCallback d, object? state) + { + ThreadPool.QueueUserWorkItem(delegate + { + SetSynchronizationContext(this); + d(state); + }); + } + } + private static async Task> InvokeAndAssertAsync( ChatOptions options, List plan,