Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
#if !NET9_0_OR_GREATER
using System.Runtime.CompilerServices;
#endif
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
Expand Down Expand Up @@ -100,8 +102,8 @@ async Task<ChatResponse> GetResponseViaSharedAsync(
ChatResponse? response = null;
await _sharedFunc(messages, options, async (messages, options, cancellationToken) =>
{
response = await InnerClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
}, cancellationToken).ConfigureAwait(false);
response = await InnerClient.GetResponseAsync(messages, options, cancellationToken);
}, cancellationToken);

if (response is null)
{
Expand Down Expand Up @@ -133,20 +135,19 @@ public override IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
{
var updates = Channel.CreateBounded<ChatResponseUpdate>(1);

#pragma warning disable CA2016 // explicitly not forwarding the cancellation token, as we need to ensure the channel is always completed
_ = Task.Run(async () =>
#pragma warning restore CA2016
_ = ProcessAsync();
async Task ProcessAsync()
{
Exception? error = null;
try
{
await _sharedFunc(messages, options, async (messages, options, cancellationToken) =>
{
await foreach (var update in InnerClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
await foreach (var update in InnerClient.GetStreamingResponseAsync(messages, options, cancellationToken))
{
await updates.Writer.WriteAsync(update, cancellationToken).ConfigureAwait(false);
await updates.Writer.WriteAsync(update, cancellationToken);
}
}, cancellationToken).ConfigureAwait(false);
}, cancellationToken);
}
catch (Exception ex)
{
Expand All @@ -157,7 +158,7 @@ await _sharedFunc(messages, options, async (messages, options, cancellationToken
{
_ = updates.Writer.TryComplete(error);
}
});
}

#if NET9_0_OR_GREATER
return updates.Reader.ReadAllAsync(cancellationToken);
Expand All @@ -166,7 +167,7 @@ await _sharedFunc(messages, options, async (messages, options, cancellationToken
static async IAsyncEnumerable<ChatResponseUpdate> ReadAllAsync(
ChannelReader<ChatResponseUpdate> channel, [EnumeratorCancellation] CancellationToken cancellationToken)
{
while (await channel.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
while (await channel.WaitToReadAsync(cancellationToken))
{
while (channel.TryRead(out var update))
{
Expand All @@ -187,7 +188,7 @@ static async IAsyncEnumerable<ChatResponseUpdate> ReadAllAsync(

static async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsyncViaGetResponseAsync(Task<ChatResponse> task)
{
ChatResponse response = await task.ConfigureAwait(false);
ChatResponse response = await task;
foreach (var update in response.ToChatResponseUpdates())
{
yield return update;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ public override async Task<ChatResponse> GetResponseAsync(
// concurrent callers might trigger duplicate requests, but that's acceptable.
var cacheKey = GetCacheKey(messages, options, _boxedFalse);

if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result)
if (await ReadCacheAsync(cacheKey, cancellationToken) is not { } result)
{
result = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false);
result = await base.GetResponseAsync(messages, options, cancellationToken);
await WriteCacheAsync(cacheKey, result, cancellationToken);
}

return result;
Expand All @@ -77,7 +77,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
// result and cache it. When we get a cache hit, we yield the non-streaming result as a streaming one.

var cacheKey = GetCacheKey(messages, options, _boxedTrue);
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatResponse)
if (await ReadCacheAsync(cacheKey, cancellationToken) is { } chatResponse)
{
// Yield all of the cached items.
foreach (var chunk in chatResponse.ToChatResponseUpdates())
Expand All @@ -89,20 +89,20 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
{
// Yield and store all of the items.
List<ChatResponseUpdate> capturedItems = [];
await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken))
{
capturedItems.Add(chunk);
yield return chunk;
}

// Write the captured items to the cache as a non-streaming result.
await WriteCacheAsync(cacheKey, capturedItems.ToChatResponse(), cancellationToken).ConfigureAwait(false);
await WriteCacheAsync(cacheKey, capturedItems.ToChatResponse(), cancellationToken);
}
}
else
{
var cacheKey = GetCacheKey(messages, options, _boxedTrue);
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks)
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken) is { } existingChunks)
{
// Yield all of the cached items.
string? chatThreadId = null;
Expand All @@ -116,14 +116,14 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
{
// Yield and store all of the items.
List<ChatResponseUpdate> capturedItems = [];
await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken))
{
capturedItems.Add(chunk);
yield return chunk;
}

// Write the captured items to the cache.
await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false);
await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using Microsoft.Extensions.AI;
using Microsoft.Shared.Diagnostics;

namespace Microsoft.Extensions.AI;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public static async Task<ChatResponse<T>> GetResponseAsync<T>(
messages = [.. messages, promptAugmentation];
}

var result = await chatClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
var result = await chatClient.GetResponseAsync(messages, options, cancellationToken);
return new ChatResponse<T>(result, serializerOptions) { IsWrappedInObject = isWrappedInObject };
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ public ConfigureOptionsChatClient(IChatClient innerClient, Action<ChatOptions> c
/// <inheritdoc/>
public override async Task<ChatResponse> GetResponseAsync(
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default) =>
await base.GetResponseAsync(messages, Configure(options), cancellationToken).ConfigureAwait(false);
await base.GetResponseAsync(messages, Configure(options), cancellationToken);

/// <inheritdoc/>
public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await foreach (var update in base.GetStreamingResponseAsync(messages, Configure(options), cancellationToken).ConfigureAwait(false))
await foreach (var update in base.GetStreamingResponseAsync(messages, Configure(options), cancellationToken))
{
yield return update;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public JsonSerializerOptions JsonSerializerOptions
_ = Throw.IfNull(key);
_jsonSerializerOptions.MakeReadOnly();

if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson)
if (await _storage.GetAsync(key, cancellationToken) is byte[] existingJson)
{
return (ChatResponse?)JsonSerializer.Deserialize(existingJson, _jsonSerializerOptions.GetTypeInfo(typeof(ChatResponse)));
}
Expand All @@ -66,7 +66,7 @@ public JsonSerializerOptions JsonSerializerOptions
_ = Throw.IfNull(key);
_jsonSerializerOptions.MakeReadOnly();

if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson)
if (await _storage.GetAsync(key, cancellationToken) is byte[] existingJson)
{
return (IReadOnlyList<ChatResponseUpdate>?)JsonSerializer.Deserialize(existingJson, _jsonSerializerOptions.GetTypeInfo(typeof(IReadOnlyList<ChatResponseUpdate>)));
}
Expand All @@ -82,7 +82,7 @@ protected override async Task WriteCacheAsync(string key, ChatResponse value, Ca
_jsonSerializerOptions.MakeReadOnly();

var newJson = JsonSerializer.SerializeToUtf8Bytes(value, _jsonSerializerOptions.GetTypeInfo(typeof(ChatResponse)));
await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false);
await _storage.SetAsync(key, newJson, cancellationToken);
}

/// <inheritdoc />
Expand All @@ -93,7 +93,7 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList
_jsonSerializerOptions.MakeReadOnly();

var newJson = JsonSerializer.SerializeToUtf8Bytes(value, _jsonSerializerOptions.GetTypeInfo(typeof(IReadOnlyList<ChatResponseUpdate>)));
await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false);
await _storage.SetAsync(key, newJson, cancellationToken);
}

/// <summary>Computes a cache key for the specified values.</summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Shared.Diagnostics;
using static Microsoft.Extensions.AI.OpenTelemetryConsts.GenAI;

#pragma warning disable CA2213 // Disposable fields should be disposed
#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test
Expand Down Expand Up @@ -233,7 +232,7 @@ public override async Task<ChatResponse> GetResponseAsync(
functionCallContents?.Clear();

// Make the call to the inner client.
response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
response = await base.GetResponseAsync(messages, options, cancellationToken);
if (response is null)
{
Throw.InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}.");
Expand Down Expand Up @@ -279,7 +278,7 @@ public override async Task<ChatResponse> GetResponseAsync(

// Add the responses from the function calls into the augmented history and also into the tracked
// list of response messages.
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, cancellationToken).ConfigureAwait(false);
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, cancellationToken);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;

Expand Down Expand Up @@ -325,7 +324,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
updates.Clear();
functionCallContents?.Clear();

await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken))
{
if (update is null)
{
Expand Down Expand Up @@ -356,7 +355,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId);

// Process all of the functions, adding their results into the history.
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, cancellationToken).ConfigureAwait(false);
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, cancellationToken);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;

Expand Down Expand Up @@ -534,7 +533,7 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin
if (functionCallContents.Count == 1)
{
FunctionInvocationResult result = await ProcessFunctionCallAsync(
messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, cancellationToken).ConfigureAwait(false);
messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, cancellationToken);

IList<ChatMessage> added = CreateResponseMessages([result]);
ThrowIfNoFunctionResultsAdded(added);
Expand All @@ -549,13 +548,15 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin

if (AllowConcurrentInvocation)
{
// Schedule the invocation of every function.
// In this case we always capture exceptions because the ordering is nondeterministic
// Rather than await'ing each function before invoking the next, invoke all of them
// and then await all of them. We avoid forcibly introducing parallelism via Task.Run,
// but if a function invocation completes asynchronously, its processing can overlap
// with the processing of other the other invocation invocations.
results = await Task.WhenAll(
from i in Enumerable.Range(0, functionCallContents.Count)
select Task.Run(() => ProcessFunctionCallAsync(
select ProcessFunctionCallAsync(
messages, options, functionCallContents,
iteration, i, captureExceptions: true, cancellationToken))).ConfigureAwait(false);
iteration, i, captureExceptions: true, cancellationToken));
}
else
{
Expand All @@ -565,7 +566,7 @@ select Task.Run(() => ProcessFunctionCallAsync(
{
results[i] = await ProcessFunctionCallAsync(
messages, options, functionCallContents,
iteration, i, captureCurrentIterationExceptions, cancellationToken).ConfigureAwait(false);
iteration, i, captureCurrentIterationExceptions, cancellationToken);
}
}

Expand Down Expand Up @@ -663,7 +664,7 @@ private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
object? result;
try
{
result = await InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false);
result = await InvokeFunctionAsync(context, cancellationToken);
}
catch (Exception e) when (!cancellationToken.IsCancellationRequested)
{
Expand Down Expand Up @@ -763,7 +764,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
try
{
CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit
result = await context.Function.InvokeAsync(context.Arguments, cancellationToken).ConfigureAwait(false);
result = await context.Function.InvokeAsync(context.Arguments, cancellationToken);
}
catch (Exception e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public override async Task<ChatResponse> GetResponseAsync(

try
{
var response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
var response = await base.GetResponseAsync(messages, options, cancellationToken);

if (_logger.IsEnabled(LogLevel.Debug))
{
Expand Down Expand Up @@ -127,7 +127,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
{
try
{
if (!await e.MoveNextAsync().ConfigureAwait(false))
if (!await e.MoveNextAsync())
{
break;
}
Expand Down Expand Up @@ -164,7 +164,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
}
finally
{
await e.DisposeAsync().ConfigureAwait(false);
await e.DisposeAsync();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ public override async Task<ChatResponse> GetResponseAsync(
Exception? error = null;
try
{
response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
response = await base.GetResponseAsync(messages, options, cancellationToken);
return response;
}
catch (Exception ex)
Expand Down Expand Up @@ -183,7 +183,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
throw;
}

var responseEnumerator = updates.ConfigureAwait(false).GetAsyncEnumerator();
var responseEnumerator = updates.GetAsyncEnumerator(cancellationToken);
List<ChatResponseUpdate> trackedUpdates = [];
Exception? error = null;
try
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ public override async Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(
{
_ = Throw.IfNull(values);

return await _generateFunc(values, options, InnerGenerator, cancellationToken).ConfigureAwait(false);
return await _generateFunc(values, options, InnerGenerator, cancellationToken);
}
}
Loading
Loading