diff --git a/.editorconfig b/.editorconfig index 0524a605..2254dff0 100644 --- a/.editorconfig +++ b/.editorconfig @@ -81,14 +81,14 @@ csharp_style_var_for_built_in_types = true csharp_style_var_when_type_is_apparent = true # Expression-bodied members -csharp_style_expression_bodied_accessors = true -csharp_style_expression_bodied_constructors = false -csharp_style_expression_bodied_indexers = true -csharp_style_expression_bodied_lambdas = true -csharp_style_expression_bodied_local_functions = false -csharp_style_expression_bodied_methods = when_on_single_line -csharp_style_expression_bodied_operators = false -csharp_style_expression_bodied_properties = true +csharp_style_expression_bodied_accessors = true:silent +csharp_style_expression_bodied_constructors = false:silent +csharp_style_expression_bodied_indexers = true:silent +csharp_style_expression_bodied_lambdas = true:silent +csharp_style_expression_bodied_local_functions = false:silent +csharp_style_expression_bodied_methods = when_on_single_line:silent +csharp_style_expression_bodied_operators = false:silent +csharp_style_expression_bodied_properties = true:silent # Pattern matching preferences csharp_style_pattern_matching_over_as_with_null_check = true @@ -106,11 +106,11 @@ csharp_prefer_static_local_function = true csharp_preferred_modifier_order = public,private,protected,internal,static,extern,new,virtual,abstract,sealed,override,readonly,unsafe,required,volatile,async # Code-block preferences -csharp_prefer_braces = true -csharp_prefer_simple_using_statement = true -csharp_style_namespace_declarations = file_scoped -csharp_style_prefer_method_group_conversion = true -csharp_style_prefer_top_level_statements = true +csharp_prefer_braces = true:silent +csharp_prefer_simple_using_statement = true:suggestion +csharp_style_namespace_declarations = file_scoped:silent +csharp_style_prefer_method_group_conversion = true:silent +csharp_style_prefer_top_level_statements = true:silent # Expression-level preferences csharp_prefer_simple_default_expression = true @@ -128,7 +128,7 @@ csharp_style_unused_value_assignment_preference = discard_variable csharp_style_unused_value_expression_statement_preference = discard_variable # 'using' directive preferences -csharp_using_directive_placement = outside_namespace +csharp_using_directive_placement = outside_namespace:silent # New line preferences csharp_style_allow_blank_line_after_colon_in_constructor_initializer_experimental = true @@ -263,4 +263,22 @@ dotnet_diagnostic.VSTHRD104.severity = none # Add .ConfigureAwait(bool) to your await expression dotnet_diagnostic.VSTHRD111.severity = none -dotnet_analyzer_diagnostic.severity = warning \ No newline at end of file +dotnet_analyzer_diagnostic.severity = warning +csharp_style_prefer_primary_constructors = true:suggestion +csharp_prefer_system_threading_lock = true:suggestion + +# CS0618: Type or member is obsolete +dotnet_diagnostic.CS0618.severity = warning + +[*.{cs,vb}] +dotnet_style_coalesce_expression = true:suggestion +dotnet_style_null_propagation = true:suggestion +dotnet_style_prefer_is_null_check_over_reference_equality_method = true:suggestion +dotnet_style_prefer_auto_properties = true:silent +dotnet_style_object_initializer = true:suggestion +dotnet_style_collection_initializer = true:suggestion +dotnet_style_operator_placement_when_wrapping = beginning_of_line +tab_width = 4 +indent_size = 4 +end_of_line = crlf +dotnet_style_prefer_simplified_boolean_expressions = true:suggestion \ No newline at end of file diff --git a/DevProxy.Abstractions/DevProxy.Abstractions.csproj b/DevProxy.Abstractions/DevProxy.Abstractions.csproj index 00286523..f14a06f4 100644 --- a/DevProxy.Abstractions/DevProxy.Abstractions.csproj +++ b/DevProxy.Abstractions/DevProxy.Abstractions.csproj @@ -1,4 +1,4 @@ - + net9.0 @@ -23,7 +23,6 @@ - diff --git a/DevProxy.Abstractions/Extensions/FuncExtensions.cs b/DevProxy.Abstractions/Extensions/FuncExtensions.cs deleted file mode 100644 index 4f0565db..00000000 --- a/DevProxy.Abstractions/Extensions/FuncExtensions.cs +++ /dev/null @@ -1,34 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -// from: https://github.com/justcoding121/titanium-web-proxy/blob/902504a324425e4e49fc5ba604c2b7fa172e68ce/src/Titanium.Web.Proxy/Extensions/FuncExtensions.cs - -#pragma warning disable IDE0130 -namespace Titanium.Web.Proxy.EventArguments; -#pragma warning restore IDE0130 - -public static class FuncExtensions -{ - internal static async Task InvokeAsync(this AsyncEventHandler callback, object sender, T args, ExceptionHandler? exceptionFunc) - { - var invocationList = callback.GetInvocationList(); - - foreach (var @delegate in invocationList) - { - await InternalInvokeAsync((AsyncEventHandler)@delegate, sender, args, exceptionFunc); - } - } - - private static async Task InternalInvokeAsync(AsyncEventHandler callback, object sender, T e, ExceptionHandler? exceptionFunc) - { - try - { - await callback(sender, e); - } - catch (Exception ex) - { - exceptionFunc?.Invoke(new InvalidOperationException("Exception thrown in user event", ex)); - } - } -} \ No newline at end of file diff --git a/DevProxy.Abstractions/Extensions/ILoggerExtensions.cs b/DevProxy.Abstractions/Extensions/ILoggerExtensions.cs index dafa6aee..ffd06c57 100644 --- a/DevProxy.Abstractions/Extensions/ILoggerExtensions.cs +++ b/DevProxy.Abstractions/Extensions/ILoggerExtensions.cs @@ -7,8 +7,9 @@ namespace Microsoft.Extensions.Logging; public static class ILoggerExtensions { - public static void LogRequest(this ILogger logger, string message, MessageType messageType, LoggingContext? context = null) + public static void LogRequest(this ILogger logger, string message, MessageType messageType, object? context = null) { + ArgumentNullException.ThrowIfNull(logger); logger.Log(new RequestLog(message, messageType, context)); } @@ -17,6 +18,11 @@ public static void LogRequest(this ILogger logger, string message, MessageType m logger.Log(new RequestLog(message, messageType, method, url)); } + public static void LogRequest(this ILogger logger, string message, MessageType messageType, HttpRequestMessage httpRequestMessage, string? requestId = null, HttpResponseMessage? httpResponse = null) + { + logger.Log(new RequestLog(message, messageType, httpRequestMessage, requestId, httpResponse)); + } + public static void Log(this ILogger logger, RequestLog message) { ArgumentNullException.ThrowIfNull(logger); diff --git a/DevProxy.Abstractions/Plugins/BasePlugin.cs b/DevProxy.Abstractions/Plugins/BasePlugin.cs index d093f9da..9c126bb2 100644 --- a/DevProxy.Abstractions/Plugins/BasePlugin.cs +++ b/DevProxy.Abstractions/Plugins/BasePlugin.cs @@ -17,49 +17,57 @@ public abstract class BasePlugin( ILogger logger, ISet urlsToWatch) : IPlugin { + /// public bool Enabled { get; protected set; } = true; - protected ILogger Logger { get; } = logger; - protected ISet UrlsToWatch { get; } = urlsToWatch; + /// + /// List of URLs to watch for this plugin. + /// + public ISet UrlsToWatch { get; } = urlsToWatch; + /// public abstract string Name { get; } + protected ILogger Logger { get; } = logger; + /// + public virtual Func>? OnRequestAsync { get; } - public virtual Option[] GetOptions() => []; - public virtual Command[] GetCommands() => []; + /// + public virtual Func? ProvideRequestGuidanceAsync { get; } - public virtual Task InitializeAsync(InitArgs e, CancellationToken cancellationToken) - { - return Task.CompletedTask; - } + /// + public virtual Func>? OnResponseAsync { get; } - public virtual void OptionsLoaded(OptionsLoadedArgs e) - { - } + /// + public virtual Func? ProvideResponseGuidanceAsync { get; } - public virtual Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) - { - return Task.CompletedTask; - } + /// + public virtual Func? HandleRequestLogAsync { get; } - public virtual Task BeforeResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) - { - return Task.CompletedTask; - } + /// + public virtual Func? HandleRecordingStopAsync { get; } - public virtual Task AfterResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) - { - return Task.CompletedTask; - } + /// + public virtual Option[] GetOptions() => []; + /// + public virtual Command[] GetCommands() => []; - public virtual Task AfterRequestLogAsync(RequestLogArgs e, CancellationToken cancellationToken) + /// + public virtual Task InitializeAsync(InitArgs e, CancellationToken cancellationToken) { return Task.CompletedTask; } - public virtual Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + /// + public virtual void OptionsLoaded(OptionsLoadedArgs e) { - return Task.CompletedTask; } + ///// + //public virtual Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + //{ + // return Task.CompletedTask; + //} + + /// public virtual Task MockRequestAsync(EventArgs e, CancellationToken cancellationToken) { return Task.CompletedTask; diff --git a/DevProxy.Abstractions/Plugins/BaseReportingPlugin.cs b/DevProxy.Abstractions/Plugins/BaseReportingPlugin.cs index 3439e13f..4b3006a5 100644 --- a/DevProxy.Abstractions/Plugins/BaseReportingPlugin.cs +++ b/DevProxy.Abstractions/Plugins/BaseReportingPlugin.cs @@ -11,18 +11,19 @@ namespace DevProxy.Abstractions.Plugins; public abstract class BaseReportingPlugin( ILogger logger, - ISet urlsToWatch) : BasePlugin(logger, urlsToWatch) + ISet urlsToWatch, + IProxyStorage proxyStorage) : BasePlugin(logger, urlsToWatch) { - protected virtual void StoreReport(object report, ProxyEventArgsBase e) + protected IProxyStorage ProxyStorage => proxyStorage; + protected virtual void StoreReport(object report) { - ArgumentNullException.ThrowIfNull(e); if (report is null) { return; } - ((Dictionary)e.GlobalData[ProxyUtils.ReportsKey])[Name] = report; + ((Dictionary)ProxyStorage.GlobalData[ProxyUtils.ReportsKey])[Name] = report; } } @@ -31,7 +32,8 @@ public abstract class BaseReportingPlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection configurationSection) : + IConfigurationSection configurationSection, + IProxyStorage proxyStorage) : BasePlugin( httpClient, logger, @@ -39,15 +41,15 @@ public abstract class BaseReportingPlugin( proxyConfiguration, configurationSection) where TConfiguration : new() { - protected virtual void StoreReport(object report, ProxyEventArgsBase e) + protected IProxyStorage ProxyStorage => proxyStorage; + protected virtual void StoreReport(object report) { - ArgumentNullException.ThrowIfNull(e); if (report is null) { return; } - ((Dictionary)e.GlobalData[ProxyUtils.ReportsKey])[Name] = report; + ((Dictionary)ProxyStorage.GlobalData[ProxyUtils.ReportsKey])[Name] = report; } } diff --git a/DevProxy.Abstractions/Plugins/IPlugin.cs b/DevProxy.Abstractions/Plugins/IPlugin.cs index a86d6d71..3a77f6db 100644 --- a/DevProxy.Abstractions/Plugins/IPlugin.cs +++ b/DevProxy.Abstractions/Plugins/IPlugin.cs @@ -9,20 +9,96 @@ namespace DevProxy.Abstractions.Plugins; +/// +/// The interface that all plugins must implement. +/// +/// We made it easy for you with the public interface IPlugin { + /// + /// Name of the plugin. + /// string Name { get; } + + /// + /// Whether the plugin is enabled or not. + /// bool Enabled { get; } Option[] GetOptions(); Command[] GetCommands(); + /// + /// Called once after the plugin is constructed, but before any requests are handled. + /// + /// + /// + /// Task InitializeAsync(InitArgs e, CancellationToken cancellationToken); + + /// + /// Handles the event triggered when options are successfully loaded. + /// + /// An instance containing the event data, including the loaded options. void OptionsLoaded(OptionsLoadedArgs e); - Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken); - Task BeforeResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken); - Task AfterResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken); - Task AfterRequestLogAsync(RequestLogArgs e, CancellationToken cancellationToken); - Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken); + + + /// + /// Implement this to handle requests. + /// + /// This is by default, so we can filter plugins based on implementation. + Func>? OnRequestAsync { get; } + + /// + /// Implement this to provide guidance for requests, you cannot modify the request or response here. + /// + /// This is by default, so we can filter plugins based on implementation. + Func? ProvideRequestGuidanceAsync { get; } + + /// + /// Implement this to modify responses from the remote server. + /// + /// This is by default, so we can filter plugins based on implementation. + Func>? OnResponseAsync { get; } + + /// + /// Implement this to provide guidance based on responses from the remote server. + /// + /// Think caching after the fact, combined with . This is by default, so we can filter plugins based on implementation. + Func? ProvideResponseGuidanceAsync { get; } + + /// + /// Implement this to receive RequestLog messages for each call. + /// + Func? HandleRequestLogAsync { get; } + + /// + /// Executes post-processing tasks after a recording has stopped. + /// + Func? HandleRecordingStopAsync { get; } + + /// + /// Receiving RequestLog messages for each call. + /// + /// This is for collecting log messages not requests itself + /// + /// + /// + //Task AfterRequestLogAsync(RequestLogArgs e, CancellationToken cancellationToken); + + /// + /// Executes post-processing tasks after a recording has stopped. + /// + /// The arguments containing details about the recording that has stopped. + /// A token to monitor for cancellation requests. + /// A task that represents the asynchronous operation. + //Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken); + + /// + /// + /// + /// + /// + /// Task MockRequestAsync(EventArgs e, CancellationToken cancellationToken); } diff --git a/DevProxy.Abstractions/Plugins/IProxyStorage.cs b/DevProxy.Abstractions/Plugins/IProxyStorage.cs new file mode 100644 index 00000000..e493b0ff --- /dev/null +++ b/DevProxy.Abstractions/Plugins/IProxyStorage.cs @@ -0,0 +1,23 @@ +[assembly: System.Runtime.CompilerServices.InternalsVisibleTo("DevProxy")] +namespace DevProxy.Abstractions.Plugins; + +/// +/// If you need either global or request-specific storage, ask for this interface in your plugin. +/// +public interface IProxyStorage +{ + /// + /// Access to global data shared across all requests. + /// + public Dictionary GlobalData { get; } + + /// + /// Get request-specific data by its ID. + /// + /// + /// + + public Dictionary GetRequestData(RequestId id); + + internal void RemoveRequestData(RequestId id); +} diff --git a/DevProxy.Abstractions/Plugins/PluginEvents.cs b/DevProxy.Abstractions/Plugins/PluginEvents.cs index de9298f0..7b59d80b 100644 --- a/DevProxy.Abstractions/Plugins/PluginEvents.cs +++ b/DevProxy.Abstractions/Plugins/PluginEvents.cs @@ -2,11 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Titanium.Web.Proxy.Http; namespace DevProxy.Abstractions.Plugins; -public class ThrottlerInfo(string throttlingKey, Func shouldThrottle, DateTime resetTime) +public class ThrottlerInfo(string throttlingKey, Func shouldThrottle, DateTime resetTime) { /// /// Time when the throttling window will be reset @@ -20,7 +19,7 @@ public class ThrottlerInfo(string throttlingKey, Func - public Func ShouldThrottle { get; private set; } = shouldThrottle ?? throw new ArgumentNullException(nameof(shouldThrottle)); + public Func ShouldThrottle { get; private set; } = shouldThrottle ?? throw new ArgumentNullException(nameof(shouldThrottle)); /// /// Throttling key used to identify which requests should be throttled. /// Can be set to a hostname, full URL or a custom string value, that diff --git a/DevProxy.Abstractions/Plugins/PluginResponse.cs b/DevProxy.Abstractions/Plugins/PluginResponse.cs new file mode 100644 index 00000000..e05d4799 --- /dev/null +++ b/DevProxy.Abstractions/Plugins/PluginResponse.cs @@ -0,0 +1,15 @@ +namespace DevProxy.Abstractions.Plugins; +public class PluginResponse +{ + public HttpRequestMessage? Request { get; private set; } + public HttpResponseMessage? Response { get; private set; } + private PluginResponse(HttpResponseMessage? response, HttpRequestMessage? request) + { + Response = response; + Request = request; + } + + public static PluginResponse Continue() => new(null, null); + public static PluginResponse Continue(HttpRequestMessage request) => new(null, request); + public static PluginResponse Respond(HttpResponseMessage response) => new(response, null); +} diff --git a/DevProxy.Abstractions/Plugins/RequestArguments.cs b/DevProxy.Abstractions/Plugins/RequestArguments.cs new file mode 100644 index 00000000..3b1c3584 --- /dev/null +++ b/DevProxy.Abstractions/Plugins/RequestArguments.cs @@ -0,0 +1,48 @@ +namespace DevProxy.Abstractions.Plugins; +/// +/// Represents the arguments for an HTTP request, including the request message and a unique request identifier. +/// +/// This class encapsulates the HTTP request message and its associated identifier, ensuring that both +/// are provided and accessible. The property contains the HTTP request details, while the property provides a unique identifier for tracking or logging purposes. +/// The HTTP request message to be sent. This parameter cannot be . +/// A unique identifier for the request. This parameter cannot be . +public class RequestArguments(HttpRequestMessage request, string requestId) +{ + /// + /// Incoming HTTP request message. + /// + public HttpRequestMessage Request { get; } = request; + + /// + /// Request identifier. + /// + public RequestId RequestId { get; } = requestId ?? throw new ArgumentNullException(nameof(requestId)); +} + +/// +/// Represents a unique identifier for a request. +/// +/// The type provides a strongly-typed representation of a request identifier, +/// encapsulating a string value. It supports implicit conversions to and from for ease of use, +/// and ensures that the identifier is not null. +/// +public record RequestId(string Id) +{ + private string Id { get; } = Id ?? throw new ArgumentNullException(nameof(Id)); + public static implicit operator string(RequestId requestId) + { + ArgumentNullException.ThrowIfNull(requestId); + return requestId.Id; + } + + public static implicit operator RequestId(string id) + { + return new(id); + } + + public static RequestId FromString(string id) + { + return new RequestId(id); + } +} \ No newline at end of file diff --git a/DevProxy.Abstractions/Plugins/ResponseArguments.cs b/DevProxy.Abstractions/Plugins/ResponseArguments.cs new file mode 100644 index 00000000..3803d27e --- /dev/null +++ b/DevProxy.Abstractions/Plugins/ResponseArguments.cs @@ -0,0 +1,5 @@ +namespace DevProxy.Abstractions.Plugins; +public class ResponseArguments(HttpRequestMessage request, HttpResponseMessage response, string requestId) : RequestArguments(request, requestId) +{ + public HttpResponseMessage Response { get; } = response; +} diff --git a/DevProxy.Abstractions/Proxy/IProxyLogger.cs b/DevProxy.Abstractions/Proxy/IProxyLogger.cs index 47106bbf..b7c34c2c 100644 --- a/DevProxy.Abstractions/Proxy/IProxyLogger.cs +++ b/DevProxy.Abstractions/Proxy/IProxyLogger.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Titanium.Web.Proxy.EventArguments; + namespace DevProxy.Abstractions.Proxy; @@ -22,8 +22,3 @@ public enum MessageType Processed, Timestamp } - -public class LoggingContext(SessionEventArgs session) -{ - public SessionEventArgs Session { get; } = session; -} \ No newline at end of file diff --git a/DevProxy.Abstractions/Proxy/ProxyEvents.cs b/DevProxy.Abstractions/Proxy/ProxyEvents.cs index 810387d3..49c37ada 100644 --- a/DevProxy.Abstractions/Proxy/ProxyEvents.cs +++ b/DevProxy.Abstractions/Proxy/ProxyEvents.cs @@ -2,10 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using DevProxy.Abstractions.Utils; using System.CommandLine; using System.Text.Json.Serialization; -using Titanium.Web.Proxy.EventArguments; namespace DevProxy.Abstractions.Proxy; @@ -15,16 +13,16 @@ public class ProxyEventArgsBase public Dictionary GlobalData { get; init; } = []; } -public class ProxyHttpEventArgsBase(SessionEventArgs session) : ProxyEventArgsBase +public class ProxyHttpEventArgsBase(object session) : ProxyEventArgsBase { - public SessionEventArgs Session { get; } = session ?? + public object Session { get; } = session ?? throw new ArgumentNullException(nameof(session)); - public bool HasRequestUrlMatch(ISet watchedUrls) => - ProxyUtils.MatchesUrlToWatch(watchedUrls, Session.HttpClient.Request.RequestUri.AbsoluteUri); + public static bool HasRequestUrlMatch(ISet _) => true; + //ProxyUtils.MatchesUrlToWatch(watchedUrls, Session.HttpClient.Request.RequestUri.AbsoluteUri); } -public class ProxyRequestArgs(SessionEventArgs session, ResponseState responseState) : +public class ProxyRequestArgs(object session, ResponseState responseState) : ProxyHttpEventArgsBase(session) { public ResponseState ResponseState { get; } = responseState ?? @@ -35,7 +33,7 @@ public bool ShouldExecute(ISet watchedUrls) => && HasRequestUrlMatch(watchedUrls); } -public class ProxyResponseArgs(SessionEventArgs session, ResponseState responseState) : +public class ProxyResponseArgs(object session, ResponseState responseState) : ProxyHttpEventArgsBase(session) { public ResponseState ResponseState { get; } = responseState ?? @@ -55,44 +53,61 @@ public class OptionsLoadedArgs(ParseResult parseResult) public class RequestLog { + //[JsonIgnore] + //public LoggingContext? Context { get; set; } [JsonIgnore] - public LoggingContext? Context { get; set; } + public HttpRequestMessage? Request { get; internal set; } + + [JsonIgnore] + public string? RequestId { get; internal set; } + + [JsonIgnore] + public HttpResponseMessage? Response { get; internal set; } + public string Message { get; set; } public MessageType MessageType { get; set; } public string? Method { get; init; } public string? PluginName { get; set; } public string? Url { get; init; } - public RequestLog(string message, MessageType messageType, LoggingContext? context) : - this(message, messageType, context?.Session.HttpClient.Request.Method, context?.Session.HttpClient.Request.Url, context) + public RequestLog(string message, MessageType messageType, object? context) { + throw new NotImplementedException("This constructor is not implemented. Use the other constructors instead."); + } + + public RequestLog(string message, MessageType messageType, HttpRequestMessage requestMessage, string? requestId = null, HttpResponseMessage? responseMessage = null) : + this(message, messageType, requestMessage?.Method.Method, requestMessage?.RequestUri!.AbsoluteUri, _: null) + { + Request = requestMessage; + Response = responseMessage; + RequestId = requestId; } public RequestLog(string message, MessageType messageType, string method, string url) : - this(message, messageType, method, url, context: null) + this(message, messageType, method, url, _: null) { } - private RequestLog(string message, MessageType messageType, string? method, string? url, LoggingContext? context) + private RequestLog(string message, MessageType messageType, string? method, string? url, object? _) { Message = message ?? throw new ArgumentNullException(nameof(message)); MessageType = messageType; - Context = context; + //Context = context; Method = method; Url = url; } - public void Deconstruct(out string message, out MessageType messageType, out LoggingContext? context, out string? method, out string? url) - { - message = Message; - messageType = MessageType; - context = Context; - method = Method; - url = Url; - } + //public void Deconstruct(out string message, out MessageType messageType, out LoggingContext? context, out string? method, out string? url) + //{ + // message = Message; + // messageType = MessageType; + // context = Context; + // method = Method; + // url = Url; + //} } -public class RecordingArgs(IEnumerable requestLogs) : ProxyEventArgsBase +public class RecordingArgs(IEnumerable requestLogs)// : ProxyEventArgsBase { public IEnumerable RequestLogs { get; set; } = requestLogs ?? throw new ArgumentNullException(nameof(requestLogs)); diff --git a/DevProxy.Abstractions/Utils/ProxyUtils.cs b/DevProxy.Abstractions/Utils/ProxyUtils.cs index ec21a389..a10c0ba1 100644 --- a/DevProxy.Abstractions/Utils/ProxyUtils.cs +++ b/DevProxy.Abstractions/Utils/ProxyUtils.cs @@ -12,7 +12,6 @@ using System.Text.Json; using System.Text.Json.Serialization; using System.Text.RegularExpressions; -using Titanium.Web.Proxy.Http; namespace DevProxy.Abstractions.Utils; @@ -84,14 +83,14 @@ static ProxyUtils() JsonSerializerOptions.Converters.Add(new JsonStringEnumConverter(JsonNamingPolicy.CamelCase)); } - public static bool IsGraphRequest(Request request) + public static bool IsGraphRequest(HttpRequestMessage request) { ArgumentNullException.ThrowIfNull(request); return IsGraphUrl(request.RequestUri); } - public static bool IsGraphUrl(Uri uri) + public static bool IsGraphUrl(Uri? uri) { ArgumentNullException.ThrowIfNull(uri); @@ -116,18 +115,18 @@ public static Uri GetAbsoluteRequestUrlFromBatch(Uri batchRequestUri, string rel return absoluteRequestUrl; } - public static bool IsSdkRequest(Request request) + public static bool IsSdkRequest(HttpRequestMessage request) { ArgumentNullException.ThrowIfNull(request); - return request.Headers.HeaderExists("SdkVersion"); + return request.Headers.Contains("SdkVersion"); } - public static bool IsGraphBetaRequest(Request request) => + public static bool IsGraphBetaRequest(HttpRequestMessage request) => IsGraphRequest(request) && IsGraphBetaUrl(request.RequestUri); - public static bool IsGraphBetaUrl(Uri uri) + public static bool IsGraphBetaUrl(Uri? uri) { ArgumentNullException.ThrowIfNull(uri); @@ -141,7 +140,7 @@ public static bool IsGraphBetaUrl(Uri uri) /// string a guid representing the a unique identifier for the request /// string representation of the date and time the request was made /// IList with defaults consistent with Microsoft Graph. Automatically adds CORS headers when the Origin header is present - public static IList BuildGraphResponseHeaders(Request request, string requestId, string requestDate) + public static IList BuildGraphResponseHeaders(HttpRequestMessage request, string requestId, string requestDate) { if (!IsGraphRequest(request)) { @@ -158,7 +157,7 @@ public static IList BuildGraphResponseHeaders(Request reques new ("Date", requestDate), new ("Content-Type", "application/json") }; - if (request.Headers.FirstOrDefault((h) => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase)) is not null) + if (request.Headers.Contains("Origin")) { headers.Add(new("Access-Control-Allow-Origin", "*")); headers.Add(new("Access-Control-Expose-Headers", "ETag, Location, Preference-Applied, Content-Range, request-id, client-request-id, ReadWriteConsistencyToken, SdkVersion, WWW-Authenticate, x-ms-client-gcc-tenant, Retry-After")); @@ -431,6 +430,13 @@ public static void MergeHeaders(IList allHeaders, IList watchedUrls, Uri? url, bool evaluateWildcards = false) + { + ArgumentNullException.ThrowIfNull(watchedUrls); + ArgumentNullException.ThrowIfNull(url); + return MatchesUrlToWatch(watchedUrls, url!.AbsoluteUri, evaluateWildcards); + } + public static bool MatchesUrlToWatch(ISet watchedUrls, string url, bool evaluateWildcards = false) { ArgumentNullException.ThrowIfNull(watchedUrls); diff --git a/DevProxy.Abstractions/packages.lock.json b/DevProxy.Abstractions/packages.lock.json index af377dc9..941198e6 100644 --- a/DevProxy.Abstractions/packages.lock.json +++ b/DevProxy.Abstractions/packages.lock.json @@ -95,28 +95,12 @@ "resolved": "2.0.0-beta5.25306.1", "contentHash": "ce0wuowuh13Cd7GXqLCq77/YWlxQMxrVCMIO/2/QUP6CdP/JWnlYSN/N3/55wwGsUwa9CvPuT8ddjgyypUr5ag==" }, - "Unobtanium.Web.Proxy": { - "type": "Direct", - "requested": "[0.1.5, )", - "resolved": "0.1.5", - "contentHash": "HiICGm0e44+i4aVHpLn+aphmSC2eQnDvlTttw1rE0hntOZKoLGRy37sydqqbRP1ZokMf3Mt0GEgSWxDwnucKGg==", - "dependencies": { - "BouncyCastle.Cryptography": "2.4.0", - "Microsoft.Extensions.Logging.Abstractions": "8.0.1", - "System.Runtime.CompilerServices.Unsafe": "6.0.0" - } - }, "YamlDotNet": { "type": "Direct", "requested": "[16.3.0, )", "resolved": "16.3.0", "contentHash": "SgMOdxbz8X65z8hraIs6hOEdnkH6hESTAIUa7viEngHOYaH+6q5XJmwr1+yb9vJpNQ19hCQY69xbFsLtXpobQA==" }, - "BouncyCastle.Cryptography": { - "type": "Transitive", - "resolved": "2.4.0", - "contentHash": "SwXsAV3sMvAU/Nn31pbjhWurYSjJ+/giI/0n6tCrYoupEK34iIHCuk3STAd9fx8yudM85KkLSVdn951vTng/vQ==" - }, "Microsoft.Data.Sqlite.Core": { "type": "Transitive", "resolved": "9.0.4", @@ -327,11 +311,6 @@ "resolved": "4.5.3", "contentHash": "3oDzvc/zzetpTKWMShs1AADwZjQ/36HnsufHRPcOjyRAAMLDlu2iD33MBI2opxnezcVUtXyqDXXjoFMOU9c7SA==" }, - "System.Runtime.CompilerServices.Unsafe": { - "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "/iUeP3tq1S0XdNNoMz5C9twLSrM/TH+qElHkXWaPvuNOt+99G75NrV0OS2EqHx5wMN7popYjpc8oTjC1y16DLg==" - }, "System.Text.Json": { "type": "Transitive", "resolved": "9.0.4", diff --git a/DevProxy.Plugins/Behavior/GenericRandomErrorPlugin.cs b/DevProxy.Plugins/Behavior/GenericRandomErrorPlugin.cs index 96c37658..caab8c11 100644 --- a/DevProxy.Plugins/Behavior/GenericRandomErrorPlugin.cs +++ b/DevProxy.Plugins/Behavior/GenericRandomErrorPlugin.cs @@ -15,8 +15,7 @@ using System.Net; using System.Text.Json; using System.Text.RegularExpressions; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; +using System.Text; namespace DevProxy.Plugins.Behavior; @@ -40,7 +39,8 @@ public sealed class GenericRandomErrorPlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BasePlugin( httpClient, logger, @@ -108,62 +108,62 @@ public override void OptionsLoaded(OptionsLoadedArgs e) } } - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); - ArgumentNullException.ThrowIfNull(e); - - if (!e.HasRequestUrlMatch(UrlsToWatch)) - { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; - } - if (e.ResponseState.HasBeenSet) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("Response already set", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request, args.RequestId); + return Task.FromResult(PluginResponse.Continue()); } var failMode = ShouldFail(); if (failMode == GenericRandomErrorFailMode.PassThru && Configuration.Rate != 100) { - Logger.LogRequest("Pass through", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("Pass through", MessageType.Skipped, args.Request, args.RequestId); + return Task.FromResult(PluginResponse.Continue()); } - FailResponse(e); - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); - return Task.CompletedTask; - } + var response = FailResponse(args.Request, args.RequestId); + if (response != null) + { + Logger.LogTrace("Left {Name}", nameof(OnRequestAsync)); + return Task.FromResult(PluginResponse.Respond(response)); + } + + Logger.LogTrace("Left {Name}", nameof(OnRequestAsync)); + return Task.FromResult(PluginResponse.Continue()); + }; // uses config to determine if a request should be failed private GenericRandomErrorFailMode ShouldFail() => _random.Next(1, 100) <= Configuration.Rate ? GenericRandomErrorFailMode.Random : GenericRandomErrorFailMode.PassThru; - private void FailResponse(ProxyRequestArgs e) + private HttpResponseMessage? FailResponse(HttpRequestMessage request, string requestId) { - var matchingResponse = GetMatchingErrorResponse(e.Session.HttpClient.Request); + var matchingResponse = GetMatchingErrorResponse(request); if (matchingResponse is not null && matchingResponse.Responses is not null) { // pick a random error response for the current request var error = matchingResponse.Responses.ElementAt(_random.Next(0, matchingResponse.Responses.Count())); - UpdateProxyResponse(e, error); + return UpdateProxyResponse(request, error, requestId); } else { - Logger.LogRequest("No matching error response found", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("No matching error response found", MessageType.Skipped, request, requestId); + return null; } } - private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) + private ThrottlingInfo ShouldThrottle(HttpRequestMessage request, string throttlingKey) { var throttleKeyForRequest = BuildThrottleKey(request); return new(throttleKeyForRequest == throttlingKey ? Configuration.RetryAfterInSeconds : 0, "Retry-After"); } - private GenericErrorResponse? GetMatchingErrorResponse(Request request) + private GenericErrorResponse? GetMatchingErrorResponse(HttpRequestMessage request) { if (Configuration.Errors is null || !Configuration.Errors.Any()) @@ -183,13 +183,13 @@ private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) return false; } - if (errorResponse.Request.Method != request.Method) + if (errorResponse.Request.Method != request.Method.Method) { return false; } - if (errorResponse.Request.Url == request.Url && - HasMatchingBody(errorResponse, request)) + if (errorResponse.Request.Url == request.RequestUri?.ToString() && + HasMatchingBodyAsync(errorResponse, request).GetAwaiter().GetResult()) { return true; } @@ -203,32 +203,34 @@ private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) // turn mock URL with wildcard into a regex and match against the request URL var errorResponseUrlRegex = Regex.Escape(errorResponse.Request.Url).Replace("\\*", ".*", StringComparison.OrdinalIgnoreCase); - return Regex.IsMatch(request.Url, $"^{errorResponseUrlRegex}$") && - HasMatchingBody(errorResponse, request); + return request.RequestUri != null && + Regex.IsMatch(request.RequestUri.ToString(), $"^{errorResponseUrlRegex}$") && + HasMatchingBodyAsync(errorResponse, request).GetAwaiter().GetResult(); }); return errorResponse; } - private void UpdateProxyResponse(ProxyRequestArgs e, GenericErrorResponseResponse error) + private HttpResponseMessage UpdateProxyResponse(HttpRequestMessage request, GenericErrorResponseResponse error, string requestId) { - var session = e.Session; - var request = session.HttpClient.Request; var headers = new List(); if (error.Headers is not null) { headers.AddRange(error.Headers); } + // Note: Global data handling for throttling is temporarily disabled + // This needs to be addressed with a proper service for managing throttled requests + // TODO: Implement proper throttling service for the new API if (error.StatusCode == (int)HttpStatusCode.TooManyRequests && error.Headers is not null && error.Headers.FirstOrDefault(h => h.Name is "Retry-After" or "retry-after")?.Value == "@dynamic") { var retryAfterDate = DateTime.Now.AddSeconds(Configuration.RetryAfterInSeconds); - if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) + if (!proxyStorage.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) { value = new List(); - e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); + proxyStorage.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); } var throttledRequests = value as List; throttledRequests?.Add(new(BuildThrottleKey(request), ShouldThrottle, retryAfterDate)); @@ -240,6 +242,9 @@ error.Headers is not null && var statusCode = (HttpStatusCode)(error.StatusCode ?? 400); var body = error.Body is null ? string.Empty : JsonSerializer.Serialize(error.Body, ProxyUtils.JsonSerializerOptions); + + var response = new HttpResponseMessage(statusCode); + // we get a JSON string so need to start with the opening quote if (body.StartsWith("\"@")) { @@ -252,20 +257,39 @@ error.Headers is not null && if (!File.Exists(filePath)) { Logger.LogError("File {FilePath} not found. Serving file path in the mock response", (string?)filePath); - session.GenericResponse(body, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); + response.Content = new StringContent(body, Encoding.UTF8, "application/json"); } else { var bodyBytes = File.ReadAllBytes(filePath); - session.GenericResponse(bodyBytes, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); + response.Content = new ByteArrayContent(bodyBytes); } } else { - session.GenericResponse(body, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); + response.Content = new StringContent(body, Encoding.UTF8, "application/json"); + } + + // Add headers to response + foreach (var header in headers) + { + if (header.Name.Equals("Content-Type", StringComparison.OrdinalIgnoreCase)) + { + // Content-Type header goes on the content, not the response + if (response.Content != null) + { + _ = response.Content.Headers.Remove("Content-Type"); + response.Content.Headers.Add("Content-Type", header.Value); + } + } + else + { + response.Headers.Add(header.Name, header.Value); + } } - e.ResponseState.HasBeenSet = true; - Logger.LogRequest($"{error.StatusCode} {statusCode}", MessageType.Chaos, new(e.Session)); + + Logger.LogRequest($"{error.StatusCode} {statusCode}", MessageType.Chaos, request, requestId); + return response; } private void ValidateErrors() @@ -316,9 +340,9 @@ private void ValidateErrors() ); } - private static bool HasMatchingBody(GenericErrorResponse errorResponse, Request request) + private static async Task HasMatchingBodyAsync(GenericErrorResponse errorResponse, HttpRequestMessage request) { - if (request.Method == "GET") + if (request.Method == HttpMethod.Get) { // GET requests don't have a body so we can't match on it return true; @@ -330,16 +354,24 @@ private static bool HasMatchingBody(GenericErrorResponse errorResponse, Request return true; } - if (!request.HasBody || string.IsNullOrEmpty(request.BodyString)) + if (request.Content == null) + { + // error response defines a body fragment but the request has no body + // so it can't match + return false; + } + + var requestBody = await request.Content.ReadAsStringAsync(); + if (string.IsNullOrEmpty(requestBody)) { // error response defines a body fragment but the request has no body // so it can't match return false; } - return request.BodyString.Contains(errorResponse.Request.BodyFragment, StringComparison.OrdinalIgnoreCase); + return requestBody.Contains(errorResponse.Request.BodyFragment, StringComparison.OrdinalIgnoreCase); } // throttle requests per host - private static string BuildThrottleKey(Request r) => r.RequestUri.Host; + private static string BuildThrottleKey(HttpRequestMessage request) => request.RequestUri?.Host ?? ""; } diff --git a/DevProxy.Plugins/Behavior/GraphRandomErrorPlugin.cs b/DevProxy.Plugins/Behavior/GraphRandomErrorPlugin.cs index d7911602..4070b97b 100644 --- a/DevProxy.Plugins/Behavior/GraphRandomErrorPlugin.cs +++ b/DevProxy.Plugins/Behavior/GraphRandomErrorPlugin.cs @@ -4,7 +4,6 @@ using DevProxy.Abstractions.Proxy; using DevProxy.Abstractions.Plugins; -using DevProxy.Abstractions.Models; using DevProxy.Abstractions.Utils; using DevProxy.Plugins.Utils; using Microsoft.Extensions.Configuration; @@ -15,8 +14,6 @@ using System.Net; using System.Text.Json; using System.Text.RegularExpressions; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; using DevProxy.Plugins.Models; namespace DevProxy.Plugins.Behavior; @@ -50,6 +47,7 @@ public sealed class GraphRandomErrorPlugin( private const string _allowedErrorsOptionName = "--allowed-errors"; private const string _rateOptionName = "--failure-rate"; + private readonly Dictionary _methodStatusCode = new() { { @@ -168,176 +166,220 @@ public override void OptionsLoaded(OptionsLoadedArgs e) } } - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); - - ArgumentNullException.ThrowIfNull(e); + ArgumentNullException.ThrowIfNull(args); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); - var state = e.ResponseState; - if (state.HasBeenSet) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("Response already set", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; - } - if (!e.HasRequestUrlMatch(UrlsToWatch)) - { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request, args.RequestId); + return Task.FromResult(PluginResponse.Continue()); } var failMode = ShouldFail(); if (failMode == GraphRandomErrorFailMode.PassThru && Configuration.Rate != 100) { - Logger.LogRequest("Pass through", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("Pass through", MessageType.Skipped, args.Request, args.RequestId); + return Task.FromResult(PluginResponse.Continue()); } - if (ProxyUtils.IsGraphBatchUrl(e.Session.HttpClient.Request.RequestUri)) + // If the request is a batch request, we will handle it in BeforeRequestAsync + if (ProxyUtils.IsGraphBatchUrl(args.Request.RequestUri!)) { - FailBatch(e); + //TODO: Build batch failure response + return Task.FromResult(PluginResponse.Continue()); } else { - FailResponse(e); + Logger.LogRequest("Pass through", MessageType.Chaos, args.Request); + return Task.FromResult(PluginResponse.Respond(FailResponse(args.Request, args.RequestId))); } - state.HasBeenSet = true; + }; - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); - return Task.CompletedTask; - } + //public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + //{ + // Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + + // ArgumentNullException.ThrowIfNull(e); + + // var state = e.ResponseState; + // if (state.HasBeenSet) + // { + // Logger.LogRequest("Response already set", MessageType.Skipped, new(e.Session)); + // return Task.CompletedTask; + // } + // if (!e.HasRequestUrlMatch(UrlsToWatch)) // ProxyUtils.MatchesUrlToWatch(watchedUrls, Session.HttpClient.Request.RequestUri.AbsoluteUri); + // { + // Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + // return Task.CompletedTask; + // } + + // var failMode = ShouldFail(); + // if (failMode == GraphRandomErrorFailMode.PassThru && Configuration.Rate != 100) + // { + // Logger.LogRequest("Pass through", MessageType.Skipped, new(e.Session)); + // return Task.CompletedTask; + // } + // if (ProxyUtils.IsGraphBatchUrl(e.Session.HttpClient.Request.RequestUri)) + // { + // FailBatch(e); + // } + // else + // { + // //FailResponse(e); + // } + // state.HasBeenSet = true; + + // Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); + // return Task.CompletedTask; + //} // uses config to determine if a request should be failed private GraphRandomErrorFailMode ShouldFail() => _random.Next(1, 100) <= Configuration.Rate ? GraphRandomErrorFailMode.Random : GraphRandomErrorFailMode.PassThru; - private void FailResponse(ProxyRequestArgs e) - { - // pick a random error response for the current request method - var methodStatusCodes = _methodStatusCode[e.Session.HttpClient.Request.Method ?? "GET"]; - var errorStatus = methodStatusCodes[_random.Next(0, methodStatusCodes.Length)]; - UpdateProxyResponse(e, errorStatus); - } + //private void FailResponse(ProxyRequestArgs e) + //{ + // // pick a random error response for the current request method + // var methodStatusCodes = _methodStatusCode[e.Session.HttpClient.Request.Method ?? "GET"]; + // var errorStatus = methodStatusCodes[_random.Next(0, methodStatusCodes.Length)]; + // UpdateProxyResponse(e, errorStatus); + //} - private void FailBatch(ProxyRequestArgs e) + private HttpResponseMessage FailResponse(HttpRequestMessage e, string requestId) { - var batchResponse = new GraphBatchResponsePayload(); - - var batch = JsonSerializer.Deserialize(e.Session.HttpClient.Request.BodyString, ProxyUtils.JsonSerializerOptions); - if (batch == null) + var methodStatusCodes = _methodStatusCode[e.Method.Method ?? "GET"]; + var errorStatus = methodStatusCodes[_random.Next(0, methodStatusCodes.Length)]; + var response = new HttpResponseMessage(errorStatus) { - UpdateProxyBatchResponse(e, batchResponse); - return; - } - var responses = new List(); - foreach (var request in batch.Requests) - { - try - { - // pick a random error response for the current request method - // if the request has dependencies, use FailedDependency status code - // https://learn.microsoft.com/en-us/graph/json-batching?tabs=http#sequencing-requests-with-the-dependson-property - var methodStatusCodes = _methodStatusCode[request.Method]; - var errorStatus = request.DependsOn is not null && request.DependsOn.Any() ? - HttpStatusCode.FailedDependency : - methodStatusCodes[_random.Next(0, methodStatusCodes.Length)]; - - var response = new GraphBatchResponsePayloadResponse + Content = new StringContent(JsonSerializer.Serialize(new GraphErrorResponseBody( + new() { - Id = request.Id, - Status = (int)errorStatus, - Body = new GraphBatchResponsePayloadResponseBody + Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), + Message = BuildApiErrorMessage(e), + InnerError = new() { - Error = new() - { - Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), - Message = "Some error was generated by the proxy.", - } + RequestId = Guid.NewGuid().ToString(), + Date = DateTime.Now.ToString(CultureInfo.CurrentCulture) } - }; - - if (errorStatus == HttpStatusCode.TooManyRequests) - { - var retryAfterDate = DateTime.Now.AddSeconds(Configuration.RetryAfterInSeconds); - var requestUrl = ProxyUtils.GetAbsoluteRequestUrlFromBatch(e.Session.HttpClient.Request.RequestUri, request.Url); - - if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) - { - value = new List(); - e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); - } - var throttledRequests = value as List; - throttledRequests?.Add(new(GraphUtils.BuildThrottleKey(requestUrl), ShouldThrottle, retryAfterDate)); - response.Headers = new() { { "Retry-After", Configuration.RetryAfterInSeconds.ToString(CultureInfo.InvariantCulture) } }; - } - responses.Add(response); - } - catch { } - } - batchResponse.Responses = [.. responses]; - - UpdateProxyBatchResponse(e, batchResponse); - } - - private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) - { - var throttleKeyForRequest = GraphUtils.BuildThrottleKey(request); - return new(throttleKeyForRequest == throttlingKey ? Configuration.RetryAfterInSeconds : 0, "Retry-After"); - } - - private void UpdateProxyResponse(ProxyRequestArgs e, HttpStatusCode errorStatus) - { - var session = e.Session; - var requestId = Guid.NewGuid().ToString(); - var requestDate = DateTime.Now.ToString(CultureInfo.CurrentCulture); - var request = session.HttpClient.Request; - var headers = ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate); - if (errorStatus == HttpStatusCode.TooManyRequests) - { - var retryAfterDate = DateTime.Now.AddSeconds(Configuration.RetryAfterInSeconds); - if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) - { - value = new List(); - e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); - } - - var throttledRequests = value as List; - throttledRequests?.Add(new(GraphUtils.BuildThrottleKey(request), ShouldThrottle, retryAfterDate)); - headers.Add(new("Retry-After", Configuration.RetryAfterInSeconds.ToString(CultureInfo.InvariantCulture))); - } - - var body = JsonSerializer.Serialize(new GraphErrorResponseBody( - new() - { - Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), - Message = BuildApiErrorMessage(request), - InnerError = new() - { - RequestId = requestId, - Date = requestDate - } - }), - ProxyUtils.JsonSerializerOptions - ); - Logger.LogRequest($"{(int)errorStatus} {errorStatus}", MessageType.Chaos, new(e.Session)); - session.GenericResponse(body ?? string.Empty, errorStatus, headers.Select(h => new HttpHeader(h.Name, h.Value))); - } - - private void UpdateProxyBatchResponse(ProxyRequestArgs ev, GraphBatchResponsePayload response) - { - // failed batch uses 200 OK status code - var errorStatus = HttpStatusCode.OK; - - var session = ev.Session; - var requestId = Guid.NewGuid().ToString(); - var requestDate = DateTime.Now.ToString(CultureInfo.CurrentCulture); - var request = session.HttpClient.Request; - var headers = ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate); - - var body = JsonSerializer.Serialize(response, ProxyUtils.JsonSerializerOptions); - Logger.LogRequest($"{(int)errorStatus} {errorStatus}", MessageType.Chaos, new(ev.Session)); - session.GenericResponse(body, errorStatus, headers.Select(h => new HttpHeader(h.Name, h.Value))); + }), ProxyUtils.JsonSerializerOptions)) + }; + Logger.LogRequest($"{(int)errorStatus} {errorStatus}", MessageType.Chaos, e, requestId); + return response; } - private static string BuildApiErrorMessage(Request r) => $"Some error was generated by the proxy. {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : string.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage()) : "")}"; + //private void FailBatch(ProxyRequestArgs e) + //{ + // var batchResponse = new GraphBatchResponsePayload(); + + // var batch = JsonSerializer.Deserialize(e.Session.HttpClient.Request.BodyString, ProxyUtils.JsonSerializerOptions); + // if (batch == null) + // { + // UpdateProxyBatchResponse(e, batchResponse); + // return; + // } + + // var responses = new List(); + // foreach (var request in batch.Requests) + // { + // try + // { + // // pick a random error response for the current request method + // var methodStatusCodes = _methodStatusCode[request.Method]; + // var errorStatus = methodStatusCodes[_random.Next(0, methodStatusCodes.Length)]; + + // var response = new GraphBatchResponsePayloadResponse + // { + // Id = request.Id, + // Status = (int)errorStatus, + // Body = new GraphBatchResponsePayloadResponseBody + // { + // Error = new() + // { + // Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), + // Message = "Some error was generated by the proxy.", + // } + // } + // }; + + // if (errorStatus == HttpStatusCode.TooManyRequests) + // { + // var retryAfterDate = DateTime.Now.AddSeconds(Configuration.RetryAfterInSeconds); + // var requestUrl = ProxyUtils.GetAbsoluteRequestUrlFromBatch(e.Session.HttpClient.Request.RequestUri, request.Url); + // var throttledRequests = e.GlobalData[RetryAfterPlugin.ThrottledRequestsKey] as List; + // throttledRequests?.Add(new(GraphUtils.BuildThrottleKey(requestUrl), ShouldThrottle, retryAfterDate)); + // response.Headers = new() { { "Retry-After", Configuration.RetryAfterInSeconds.ToString(CultureInfo.InvariantCulture) } }; + // } + + // responses.Add(response); + // } + // catch { } + // } + // batchResponse.Responses = [.. responses]; + + // UpdateProxyBatchResponse(e, batchResponse); + //} + + //private ThrottlingInfo ShouldThrottle(HttpRequestMessage request, string throttlingKey) + //{ + // var throttleKeyForRequest = GraphUtils.BuildThrottleKey(request); + // return new(throttleKeyForRequest == throttlingKey ? Configuration.RetryAfterInSeconds : 0, "Retry-After"); + //} + + //private void UpdateProxyResponse(ProxyRequestArgs e, HttpStatusCode errorStatus) + //{ + // var session = e.Session; + // var requestId = Guid.NewGuid().ToString(); + // var requestDate = DateTime.Now.ToString(CultureInfo.CurrentCulture); + // var request = session.HttpClient.Request; + // var headers = ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate); + // if (errorStatus == HttpStatusCode.TooManyRequests) + // { + // var retryAfterDate = DateTime.Now.AddSeconds(Configuration.RetryAfterInSeconds); + // if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) + // { + // value = new List(); + // e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); + // } + + // var throttledRequests = value as List; + // throttledRequests?.Add(new(GraphUtils.BuildThrottleKey(request), ShouldThrottle, retryAfterDate)); + // headers.Add(new("Retry-After", Configuration.RetryAfterInSeconds.ToString(CultureInfo.InvariantCulture))); + // } + + // var body = JsonSerializer.Serialize(new GraphErrorResponseBody( + // new() + // { + // Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), + // Message = BuildApiErrorMessage(request), + // InnerError = new() + // { + // RequestId = requestId, + // Date = requestDate + // } + // }), + // ProxyUtils.JsonSerializerOptions + // ); + // Logger.LogRequest($"{(int)errorStatus} {errorStatus}", MessageType.Chaos, new(e.Session)); + // session.GenericResponse(body ?? string.Empty, errorStatus, headers.Select(h => new HttpHeader(h.Name, h.Value))); + //} + + //private void UpdateProxyBatchResponse(ProxyRequestArgs ev, GraphBatchResponsePayload response) + //{ + // // failed batch uses a fixed 424 error status code + // var errorStatus = HttpStatusCode.FailedDependency; + + // var session = ev.Session; + // var requestId = Guid.NewGuid().ToString(); + // var requestDate = DateTime.Now.ToString(CultureInfo.CurrentCulture); + // var request = session.HttpClient.Request; + // var headers = ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate); + + // var body = JsonSerializer.Serialize(response, ProxyUtils.JsonSerializerOptions); + // Logger.LogRequest($"{(int)errorStatus} {errorStatus}", MessageType.Chaos, new(ev.Session)); + // session.GenericResponse(body, errorStatus, headers.Select(h => new HttpHeader(h.Name, h.Value))); + //} + + private static string BuildApiErrorMessage(HttpRequestMessage r) => $"Some error was generated by the proxy. {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : string.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage()) : "")}"; } diff --git a/DevProxy.Plugins/Behavior/LanguageModelFailurePlugin.cs b/DevProxy.Plugins/Behavior/LanguageModelFailurePlugin.cs index 9995208f..ed73fbc7 100644 --- a/DevProxy.Plugins/Behavior/LanguageModelFailurePlugin.cs +++ b/DevProxy.Plugins/Behavior/LanguageModelFailurePlugin.cs @@ -47,51 +47,43 @@ public sealed class LanguageModelFailurePlugin( public override string Name => nameof(LanguageModelFailurePlugin); - public override async Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => async (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); - ArgumentNullException.ThrowIfNull(e); - - if (!e.HasRequestUrlMatch(UrlsToWatch)) - { - Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); - return; - } - if (e.ResponseState.HasBeenSet) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("Response already set", MessageType.Skipped, new LoggingContext(e.Session)); - return; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request, args.RequestId); + return PluginResponse.Continue(); } - var request = e.Session.HttpClient.Request; - if (request.Method is null || - !request.Method.Equals("POST", StringComparison.OrdinalIgnoreCase) || - !request.HasBody) + if (args.Request.Method != HttpMethod.Post || args.Request.Content == null) { - Logger.LogRequest("Request is not a POST request with a body", MessageType.Skipped, new(e.Session)); - return; + Logger.LogRequest("Request is not a POST request with a body", MessageType.Skipped, args.Request, args.RequestId); + return PluginResponse.Continue(); } - if (!TryGetOpenAIRequest(request.BodyString, out var openAiRequest)) + var requestBody = await args.Request.Content.ReadAsStringAsync(cancellationToken); + if (!TryGetOpenAIRequest(requestBody, out var openAiRequest)) { - Logger.LogRequest("Skipping non-OpenAI request", MessageType.Skipped, new(e.Session)); - return; + Logger.LogRequest("Skipping non-OpenAI request", MessageType.Skipped, args.Request, args.RequestId); + return PluginResponse.Continue(); } var (faultName, faultPrompt) = GetFault(); if (faultPrompt is null) { Logger.LogError("Failed to get fault prompt. Passing request as-is."); - return; + return PluginResponse.Continue(); } + string modifiedRequestBody; if (openAiRequest is OpenAICompletionRequest completionRequest) { completionRequest.Prompt += "\n\n" + faultPrompt; Logger.LogDebug("Modified completion request prompt: {Prompt}", completionRequest.Prompt); - Logger.LogRequest($"Simulating fault {faultName}", MessageType.Chaos, new(e.Session)); - e.Session.SetRequestBodyString(JsonSerializer.Serialize(completionRequest, ProxyUtils.JsonSerializerOptions)); + Logger.LogRequest($"Simulating fault {faultName}", MessageType.Chaos, args.Request, args.RequestId); + modifiedRequestBody = JsonSerializer.Serialize(completionRequest, ProxyUtils.JsonSerializerOptions); } else if (openAiRequest is OpenAIChatCompletionRequest chatRequest) { @@ -113,18 +105,43 @@ public override async Task BeforeRequestAsync(ProxyRequestArgs e, CancellationTo }; Logger.LogDebug("Added fault prompt to messages: {Prompt}", faultPrompt); - Logger.LogRequest($"Simulating fault {faultName}", MessageType.Chaos, new(e.Session)); - e.Session.SetRequestBodyString(JsonSerializer.Serialize(newRequest, ProxyUtils.JsonSerializerOptions)); + Logger.LogRequest($"Simulating fault {faultName}", MessageType.Chaos, args.Request, args.RequestId); + modifiedRequestBody = JsonSerializer.Serialize(newRequest, ProxyUtils.JsonSerializerOptions); } else { Logger.LogDebug("Unknown OpenAI request type. Passing request as-is."); + return PluginResponse.Continue(); } - await Task.CompletedTask; + // Create new request with modified body + var modifiedRequest = new HttpRequestMessage(args.Request.Method, args.Request.RequestUri) + { + Content = new StringContent(modifiedRequestBody, System.Text.Encoding.UTF8, "application/json") + }; + + // Copy headers from original request + foreach (var header in args.Request.Headers) + { + _ = modifiedRequest.Headers.TryAddWithoutValidation(header.Key, header.Value); + } - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); - } + // Copy content headers if they exist + if (args.Request.Content?.Headers != null) + { + foreach (var header in args.Request.Content.Headers) + { + if (!header.Key.Equals("Content-Type", StringComparison.OrdinalIgnoreCase) && + !header.Key.Equals("Content-Length", StringComparison.OrdinalIgnoreCase)) + { + _ = modifiedRequest.Content.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + } + } + + Logger.LogTrace("Left {Name}", nameof(OnRequestAsync)); + return PluginResponse.Continue(modifiedRequest); + }; private bool TryGetOpenAIRequest(string content, out OpenAIRequest? request) { diff --git a/DevProxy.Plugins/Behavior/LanguageModelRateLimitingPlugin.cs b/DevProxy.Plugins/Behavior/LanguageModelRateLimitingPlugin.cs index 5dcc2121..ff044326 100644 --- a/DevProxy.Plugins/Behavior/LanguageModelRateLimitingPlugin.cs +++ b/DevProxy.Plugins/Behavior/LanguageModelRateLimitingPlugin.cs @@ -12,9 +12,8 @@ using Microsoft.Extensions.Logging; using System.Globalization; using System.Net; +using System.Text; using System.Text.Json; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; namespace DevProxy.Plugins.Behavior; @@ -40,7 +39,8 @@ public sealed class LanguageModelRateLimitingPlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BasePlugin( httpClient, logger, @@ -48,8 +48,7 @@ public sealed class LanguageModelRateLimitingPlugin( proxyConfiguration, pluginConfigurationSection) { - // initial values so that we know when we intercept the - // first request and can set the initial values + private readonly IProxyStorage _proxyStorage = proxyStorage; private int _promptTokensRemaining = -1; private int _completionTokensRemaining = -1; private DateTime _resetTime = DateTime.MinValue; @@ -71,38 +70,27 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell } } - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => async (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); - ArgumentNullException.ThrowIfNull(e); - - var session = e.Session; - var state = e.ResponseState; - if (state.HasBeenSet) - { - Logger.LogRequest("Response already set", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; - } - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); + return PluginResponse.Continue(); } - var request = e.Session.HttpClient.Request; - if (request.Method is null || - !request.Method.Equals("POST", StringComparison.OrdinalIgnoreCase) || - !request.HasBody) + if (args.Request.Method != HttpMethod.Post || args.Request.Content == null) { - Logger.LogRequest("Request is not a POST request with a body", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("Request is not a POST request with a body", MessageType.Skipped, args.Request); + return PluginResponse.Continue(); } - if (!TryGetOpenAIRequest(request.BodyString, out var openAiRequest)) + var bodyString = await args.Request.Content.ReadAsStringAsync(cancellationToken); + if (!TryGetOpenAIRequest(bodyString, out var openAiRequest)) { - Logger.LogRequest("Skipping non-OpenAI request", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("Skipping non-OpenAI request", MessageType.Skipped, args.Request); + return PluginResponse.Continue(); } // set the initial values for the first request @@ -127,71 +115,83 @@ public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken ca // check if we have tokens available if (_promptTokensRemaining <= 0 || _completionTokensRemaining <= 0) { - Logger.LogRequest($"Exceeded token limit when calling {request.Url}. Request will be throttled", MessageType.Failed, new(e.Session)); + Logger.LogRequest($"Exceeded token limit when calling {args.Request.RequestUri}. Request will be throttled", MessageType.Failed, args.Request); if (Configuration.WhenLimitExceeded == TokenLimitResponseWhenExceeded.Throttle) { - if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) + // Add throttling info to global data for RetryAfterPlugin coordination + if (!_proxyStorage.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) { value = new List(); - e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); + _proxyStorage.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); } var throttledRequests = value as List; throttledRequests?.Add(new( - BuildThrottleKey(request), + BuildThrottleKey(args.Request), ShouldThrottle, _resetTime )); - ThrottleResponse(e); - state.HasBeenSet = true; + + return PluginResponse.Respond(BuildThrottleResponse(args.Request)); } else { if (Configuration.CustomResponse is not null) { - var headersList = Configuration.CustomResponse.Headers is not null ? - Configuration.CustomResponse.Headers.Select(h => new HttpHeader(h.Name, h.Value)).ToList() : - []; - - var retryAfterHeader = headersList.FirstOrDefault(h => h.Name.Equals(Configuration.HeaderRetryAfter, StringComparison.OrdinalIgnoreCase)); - if (retryAfterHeader is not null && retryAfterHeader.Value == "@dynamic") - { - headersList.Add(new(Configuration.HeaderRetryAfter, ((int)(_resetTime - DateTime.Now).TotalSeconds).ToString(CultureInfo.InvariantCulture))); - _ = headersList.Remove(retryAfterHeader); - } - - var headers = headersList.ToArray(); - - // allow custom throttling response var responseCode = (HttpStatusCode)(Configuration.CustomResponse.StatusCode ?? 200); + + // Add throttling info for TooManyRequests responses if (responseCode == HttpStatusCode.TooManyRequests) { - if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) + if (!_proxyStorage.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) { value = new List(); - e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); + _proxyStorage.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); } var throttledRequests = value as List; throttledRequests?.Add(new( - BuildThrottleKey(request), + BuildThrottleKey(args.Request), ShouldThrottle, _resetTime )); } - string body = Configuration.CustomResponse.Body is not null ? - JsonSerializer.Serialize(Configuration.CustomResponse.Body, ProxyUtils.JsonSerializerOptions) : - ""; - e.Session.GenericResponse(body, responseCode, headers); - state.HasBeenSet = true; + var response = new HttpResponseMessage(responseCode) + { + Content = new StringContent( + Configuration.CustomResponse.Body is not null ? + JsonSerializer.Serialize(Configuration.CustomResponse.Body, ProxyUtils.JsonSerializerOptions) : + string.Empty, + Encoding.UTF8, + "application/json") + }; + + // Add headers + if (Configuration.CustomResponse.Headers is not null) + { + foreach (var header in Configuration.CustomResponse.Headers) + { + var headerValue = header.Value; + if (header.Name.Equals(Configuration.HeaderRetryAfter, StringComparison.OrdinalIgnoreCase) && headerValue == "@dynamic") + { + headerValue = ((int)(_resetTime - DateTime.Now).TotalSeconds).ToString(CultureInfo.InvariantCulture); + } + _ = response.Headers.TryAddWithoutValidation(header.Name, headerValue); + } + } + + return PluginResponse.Respond(response); } else { - Logger.LogRequest($"Custom behavior not set. {Configuration.CustomResponseFile} not found.", MessageType.Failed, new(e.Session)); - e.Session.GenericResponse("Custom response file not found.", HttpStatusCode.InternalServerError, []); - state.HasBeenSet = true; + Logger.LogRequest($"Custom behavior not set. {Configuration.CustomResponseFile} not found.", MessageType.Failed, args.Request); + var response = new HttpResponseMessage(HttpStatusCode.InternalServerError) + { + Content = new StringContent("Custom response file not found.", Encoding.UTF8, "text/plain") + }; + return PluginResponse.Respond(response); } } } @@ -200,41 +200,37 @@ public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken ca Logger.LogDebug("Tokens remaining - Prompt: {PromptTokensRemaining}, Completion: {CompletionTokensRemaining}", _promptTokensRemaining, _completionTokensRemaining); } - return Task.CompletedTask; - } + return PluginResponse.Continue(); + }; - public override Task BeforeResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) + public override Func>? OnResponseAsync => async (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeResponseAsync)); - - ArgumentNullException.ThrowIfNull(e); + Logger.LogTrace("{Method} called", nameof(OnResponseAsync)); - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); + return null; } - var request = e.Session.HttpClient.Request; - if (request.Method is null || - !request.Method.Equals("POST", StringComparison.OrdinalIgnoreCase) || - !request.HasBody) + if (args.Request.Method != HttpMethod.Post || args.Request.Content == null) { Logger.LogDebug("Skipping non-POST request"); - return Task.CompletedTask; + return null; } - if (!TryGetOpenAIRequest(request.BodyString, out var openAiRequest)) + var bodyString = await args.Request.Content.ReadAsStringAsync(cancellationToken); + if (!TryGetOpenAIRequest(bodyString, out var openAiRequest)) { Logger.LogDebug("Skipping non-OpenAI request"); - return Task.CompletedTask; + return null; } // Read the response body to get token usage - var response = e.Session.HttpClient.Response; - if (response.HasBody) + var httpResponse = args.Response; + if (httpResponse.Content != null) { - var responseBody = response.BodyString; + var responseBody = await httpResponse.Content.ReadAsStringAsync(cancellationToken); if (!string.IsNullOrEmpty(responseBody)) { try @@ -257,7 +253,7 @@ public override Task BeforeResponseAsync(ProxyResponseArgs e, CancellationToken _completionTokensRemaining = 0; } - Logger.LogRequest($"Consumed {promptTokens} prompt tokens and {completionTokens} completion tokens. Remaining - Prompt: {_promptTokensRemaining}, Completion: {_completionTokensRemaining}", MessageType.Processed, new(e.Session)); + Logger.LogRequest($"Consumed {promptTokens} prompt tokens and {completionTokens} completion tokens. Remaining - Prompt: {_promptTokensRemaining}, Completion: {_completionTokensRemaining}", MessageType.Processed, args.Request); } } catch (JsonException ex) @@ -267,9 +263,9 @@ public override Task BeforeResponseAsync(ProxyResponseArgs e, CancellationToken } } - Logger.LogTrace("Left {Name}", nameof(BeforeResponseAsync)); - return Task.CompletedTask; - } + Logger.LogTrace("Left {Name}", nameof(OnResponseAsync)); + return null; + }; private bool TryGetOpenAIRequest(string content, out OpenAIRequest? request) { @@ -310,7 +306,7 @@ private bool TryGetOpenAIRequest(string content, out OpenAIRequest? request) } } - private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) + private ThrottlingInfo ShouldThrottle(HttpRequestMessage request, string throttlingKey) { var throttleKeyForRequest = BuildThrottleKey(request); return new(throttleKeyForRequest == throttlingKey ? @@ -318,12 +314,8 @@ private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) Configuration.HeaderRetryAfter); } - private void ThrottleResponse(ProxyRequestArgs e) + private HttpResponseMessage BuildThrottleResponse(HttpRequestMessage request) { - var headers = new List(); - var body = string.Empty; - var request = e.Session.HttpClient.Request; - // Build standard OpenAI error response for token limit exceeded var openAiError = new { @@ -335,17 +327,23 @@ private void ThrottleResponse(ProxyRequestArgs e) code = "insufficient_quota" } }; - body = JsonSerializer.Serialize(openAiError, ProxyUtils.JsonSerializerOptions); + var body = JsonSerializer.Serialize(openAiError, ProxyUtils.JsonSerializerOptions); + + var response = new HttpResponseMessage(HttpStatusCode.TooManyRequests) + { + Content = new StringContent(body, Encoding.UTF8, "application/json") + }; + + _ = response.Headers.TryAddWithoutValidation(Configuration.HeaderRetryAfter, ((int)(_resetTime - DateTime.Now).TotalSeconds).ToString(CultureInfo.InvariantCulture)); - headers.Add(new(Configuration.HeaderRetryAfter, ((int)(_resetTime - DateTime.Now).TotalSeconds).ToString(CultureInfo.InvariantCulture))); - if (request.Headers.Any(h => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase))) + if (request.Headers.TryGetValues("Origin", out var _)) { - headers.Add(new("Access-Control-Allow-Origin", "*")); - headers.Add(new("Access-Control-Expose-Headers", Configuration.HeaderRetryAfter)); + _ = response.Headers.TryAddWithoutValidation("Access-Control-Allow-Origin", "*"); + _ = response.Headers.TryAddWithoutValidation("Access-Control-Expose-Headers", Configuration.HeaderRetryAfter); } - e.Session.GenericResponse(body, HttpStatusCode.TooManyRequests, [.. headers.Select(h => new HttpHeader(h.Name, h.Value))]); + return response; } - private static string BuildThrottleKey(Request r) => r.RequestUri.Host; + private static string BuildThrottleKey(HttpRequestMessage r) => r.RequestUri?.Host ?? string.Empty; } diff --git a/DevProxy.Plugins/Behavior/LatencyPlugin.cs b/DevProxy.Plugins/Behavior/LatencyPlugin.cs index 5a6df3a8..cd10e8b9 100644 --- a/DevProxy.Plugins/Behavior/LatencyPlugin.cs +++ b/DevProxy.Plugins/Behavior/LatencyPlugin.cs @@ -4,6 +4,7 @@ using DevProxy.Abstractions.Plugins; using DevProxy.Abstractions.Proxy; +using DevProxy.Abstractions.Utils; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; @@ -32,22 +33,21 @@ public sealed class LatencyPlugin( public override string Name => nameof(LatencyPlugin); - public override async Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => async (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); - ArgumentNullException.ThrowIfNull(e); - - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request, args.RequestId); + return PluginResponse.Continue(); } var delay = _random.Next(Configuration.MinMs, Configuration.MaxMs); - Logger.LogRequest($"Delaying request for {delay}ms", MessageType.Chaos, new(e.Session)); + Logger.LogRequest($"Delaying request for {delay}ms", MessageType.Chaos, args.Request, args.RequestId); await Task.Delay(delay, cancellationToken); - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); - } + Logger.LogTrace("Left {Name}", nameof(OnRequestAsync)); + return PluginResponse.Continue(); + }; } diff --git a/DevProxy.Plugins/Behavior/RateLimitingPlugin.cs b/DevProxy.Plugins/Behavior/RateLimitingPlugin.cs index 44864c19..58875175 100644 --- a/DevProxy.Plugins/Behavior/RateLimitingPlugin.cs +++ b/DevProxy.Plugins/Behavior/RateLimitingPlugin.cs @@ -13,10 +13,9 @@ using Microsoft.Extensions.Logging; using System.Globalization; using System.Net; +using System.Text; using System.Text.Json; using System.Text.RegularExpressions; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; namespace DevProxy.Plugins.Behavior; @@ -53,7 +52,8 @@ public sealed class RateLimitingPlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BasePlugin( httpClient, logger, @@ -61,8 +61,7 @@ public sealed class RateLimitingPlugin( proxyConfiguration, pluginConfigurationSection) { - // initial values so that we know when we intercept the - // first request and can set the initial values + private readonly IProxyStorage _proxyStorage = proxyStorage; private int _resourcesRemaining = -1; private DateTime _resetTime = DateTime.MinValue; private RateLimitingCustomResponseLoader? _loader; @@ -83,23 +82,14 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell } } - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); - ArgumentNullException.ThrowIfNull(e); - - var session = e.Session; - var state = e.ResponseState; - if (state.HasBeenSet) - { - Logger.LogRequest("Response already set", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; - } - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); + return Task.FromResult(PluginResponse.Continue()); } // set the initial values for the first request @@ -124,106 +114,122 @@ public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken ca if (_resourcesRemaining < 0) { _resourcesRemaining = 0; - var request = e.Session.HttpClient.Request; - Logger.LogRequest($"Exceeded resource limit when calling {request.Url}. Request will be throttled", MessageType.Failed, new(e.Session)); + Logger.LogRequest($"Exceeded resource limit when calling {args.Request.RequestUri}. Request will be throttled", MessageType.Failed, args.Request); if (Configuration.WhenLimitExceeded == RateLimitResponseWhenLimitExceeded.Throttle) { - if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) + // Add throttling info to global data for RetryAfterPlugin coordination + if (!_proxyStorage.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) { value = new List(); - e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); + _proxyStorage.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); } var throttledRequests = value as List; throttledRequests?.Add(new( - BuildThrottleKey(request), + BuildThrottleKey(args.Request), ShouldThrottle, _resetTime )); - ThrottleResponse(e); - state.HasBeenSet = true; + + return Task.FromResult(PluginResponse.Respond(BuildThrottleResponse(args.Request))); } else { if (Configuration.CustomResponse is not null) { - var headersList = Configuration.CustomResponse.Headers is not null ? - Configuration.CustomResponse.Headers.Select(h => new HttpHeader(h.Name, h.Value)).ToList() : - []; - - var retryAfterHeader = headersList.FirstOrDefault(h => h.Name.Equals(Configuration.HeaderRetryAfter, StringComparison.OrdinalIgnoreCase)); - if (retryAfterHeader is not null && retryAfterHeader.Value == "@dynamic") - { - headersList.Add(new(Configuration.HeaderRetryAfter, ((int)(_resetTime - DateTime.Now).TotalSeconds).ToString(CultureInfo.InvariantCulture))); - _ = headersList.Remove(retryAfterHeader); - } - - var headers = headersList.ToArray(); - - // allow custom throttling response var responseCode = (HttpStatusCode)(Configuration.CustomResponse.StatusCode ?? 200); + + // Add throttling info for TooManyRequests responses if (responseCode == HttpStatusCode.TooManyRequests) { - if (!e.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) + if (!_proxyStorage.GlobalData.TryGetValue(RetryAfterPlugin.ThrottledRequestsKey, out var value)) { value = new List(); - e.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); + _proxyStorage.GlobalData.Add(RetryAfterPlugin.ThrottledRequestsKey, value); } var throttledRequests = value as List; throttledRequests?.Add(new( - BuildThrottleKey(request), + BuildThrottleKey(args.Request), ShouldThrottle, _resetTime )); } - string body = Configuration.CustomResponse.Body is not null ? - JsonSerializer.Serialize(Configuration.CustomResponse.Body, ProxyUtils.JsonSerializerOptions) : - ""; - e.Session.GenericResponse(body, responseCode, headers); - state.HasBeenSet = true; + var response = new HttpResponseMessage(responseCode) + { + Content = new StringContent( + Configuration.CustomResponse.Body is not null ? + JsonSerializer.Serialize(Configuration.CustomResponse.Body, ProxyUtils.JsonSerializerOptions) : + string.Empty, + Encoding.UTF8, + "application/json") + }; + + // Add headers + if (Configuration.CustomResponse.Headers is not null) + { + foreach (var header in Configuration.CustomResponse.Headers) + { + var headerValue = header.Value; + if (header.Name.Equals(Configuration.HeaderRetryAfter, StringComparison.OrdinalIgnoreCase) && headerValue == "@dynamic") + { + headerValue = ((int)(_resetTime - DateTime.Now).TotalSeconds).ToString(CultureInfo.InvariantCulture); + } + _ = response.Headers.TryAddWithoutValidation(header.Name, headerValue); + } + } + + return Task.FromResult(PluginResponse.Respond(response)); } else { - Logger.LogRequest($"Custom behavior not set. {Configuration.CustomResponseFile} not found.", MessageType.Failed, new(e.Session)); + Logger.LogRequest($"Custom behavior not set. {Configuration.CustomResponseFile} not found.", MessageType.Failed, args.Request); + var response = new HttpResponseMessage(HttpStatusCode.InternalServerError) + { + Content = new StringContent("Custom response file not found.", Encoding.UTF8, "text/plain") + }; + return Task.FromResult(PluginResponse.Respond(response)); } } } else { - Logger.LogRequest($"Resources remaining: {_resourcesRemaining}", MessageType.Skipped, new(e.Session)); + Logger.LogRequest($"Resources remaining: {_resourcesRemaining}", MessageType.Skipped, args.Request); } - StoreRateLimitingHeaders(e); - return Task.CompletedTask; - } + StoreRateLimitingHeaders(args); + return Task.FromResult(PluginResponse.Continue()); + }; - public override Task BeforeResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) + public override Func>? OnResponseAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeResponseAsync)); + Logger.LogTrace("{Method} called", nameof(OnResponseAsync)); - ArgumentNullException.ThrowIfNull(e); - - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); + return Task.FromResult(null); } - if (e.ResponseState.HasBeenSet) + + // Add rate limiting headers to the response if we have them stored + var requestData = _proxyStorage.GetRequestData(args.RequestId); + if (requestData.TryGetValue(Name, out var pluginData) && + pluginData is List rateLimitingHeaders) { - Logger.LogRequest("Response already set", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + var response = args.Response; + foreach (var header in rateLimitingHeaders) + { + _ = response.Headers.TryAddWithoutValidation(header.Name, header.Value); + } } - UpdateProxyResponse(e, HttpStatusCode.OK); + Logger.LogTrace("Left {Name}", nameof(OnResponseAsync)); + return Task.FromResult(null); + }; - Logger.LogTrace("Left {Name}", nameof(BeforeResponseAsync)); - return Task.CompletedTask; - } - - private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) + private ThrottlingInfo ShouldThrottle(HttpRequestMessage request, string throttlingKey) { var throttleKeyForRequest = BuildThrottleKey(request); return new(throttleKeyForRequest == throttlingKey ? @@ -231,62 +237,54 @@ private ThrottlingInfo ShouldThrottle(Request request, string throttlingKey) Configuration.HeaderRetryAfter); } - private void ThrottleResponse(ProxyRequestArgs e) => UpdateProxyResponse(e, HttpStatusCode.TooManyRequests); - - private void UpdateProxyResponse(ProxyHttpEventArgsBase e, HttpStatusCode errorStatus) + private HttpResponseMessage BuildThrottleResponse(HttpRequestMessage request) { var headers = new List(); var body = string.Empty; - var request = e.Session.HttpClient.Request; - var response = e.Session.HttpClient.Response; // resources exceeded - if (errorStatus == HttpStatusCode.TooManyRequests) + if (ProxyUtils.IsGraphRequest(request)) { - if (ProxyUtils.IsGraphRequest(request)) - { - var requestId = Guid.NewGuid().ToString(); - var requestDate = DateTime.Now.ToString(CultureInfo.CurrentCulture); - headers.AddRange(ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate)); + var requestId = Guid.NewGuid().ToString(); + var requestDate = DateTime.Now.ToString(CultureInfo.CurrentCulture); + headers.AddRange(ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate)); - body = JsonSerializer.Serialize(new GraphErrorResponseBody( - new() + body = JsonSerializer.Serialize(new GraphErrorResponseBody( + new() + { + Code = new Regex("([A-Z])").Replace(HttpStatusCode.TooManyRequests.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), + Message = BuildApiErrorMessage(request), + InnerError = new() { - Code = new Regex("([A-Z])").Replace(errorStatus.ToString(), m => { return $" {m.Groups[1]}"; }).Trim(), - Message = BuildApiErrorMessage(request), - InnerError = new() - { - RequestId = requestId, - Date = requestDate - } - }), - ProxyUtils.JsonSerializerOptions - ); - } - - headers.Add(new(Configuration.HeaderRetryAfter, ((int)(_resetTime - DateTime.Now).TotalSeconds).ToString(CultureInfo.InvariantCulture))); - if (request.Headers.Any(h => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase))) - { - headers.Add(new("Access-Control-Allow-Origin", "*")); - headers.Add(new("Access-Control-Expose-Headers", Configuration.HeaderRetryAfter)); - } + RequestId = requestId, + Date = requestDate + } + }), + ProxyUtils.JsonSerializerOptions + ); + } - e.Session.GenericResponse(body ?? string.Empty, errorStatus, [.. headers.Select(h => new HttpHeader(h.Name, h.Value))]); - return; + headers.Add(new(Configuration.HeaderRetryAfter, ((int)(_resetTime - DateTime.Now).TotalSeconds).ToString(CultureInfo.InvariantCulture))); + if (request.Headers.TryGetValues("Origin", out var _)) + { + headers.Add(new("Access-Control-Allow-Origin", "*")); + headers.Add(new("Access-Control-Expose-Headers", Configuration.HeaderRetryAfter)); } - if (e.SessionData.TryGetValue(Name, out var pluginData) && - pluginData is List rateLimitingHeaders) + var response = new HttpResponseMessage(HttpStatusCode.TooManyRequests) + { + Content = new StringContent(body ?? string.Empty, Encoding.UTF8, "application/json") + }; + + foreach (var header in headers) { - ProxyUtils.MergeHeaders(headers, rateLimitingHeaders); + _ = response.Headers.TryAddWithoutValidation(header.Name, header.Value); } - // add headers to the original API response, avoiding duplicates - headers.ForEach(h => e.Session.HttpClient.Response.Headers.RemoveHeader(h.Name)); - e.Session.HttpClient.Response.Headers.AddHeaders(headers.Select(h => new HttpHeader(h.Name, h.Value)).ToArray()); + return response; } - private void StoreRateLimitingHeaders(ProxyRequestArgs e) + private void StoreRateLimitingHeaders(RequestArguments args) { // add rate limiting headers if reached the threshold percentage if (_resourcesRemaining > Configuration.RateLimit - (Configuration.RateLimit * Configuration.WarningThresholdPercent / 100)) @@ -305,15 +303,15 @@ private void StoreRateLimitingHeaders(ProxyRequestArgs e) new(Configuration.HeaderReset, reset) ]); - ExposeRateLimitingForCors(headers, e); + ExposeRateLimitingForCors(headers, args.Request); - e.SessionData.Add(Name, headers); + var requestData = _proxyStorage.GetRequestData(args.RequestId); + requestData.Add(Name, headers); } - private void ExposeRateLimitingForCors(List headers, ProxyRequestArgs e) + private void ExposeRateLimitingForCors(List headers, HttpRequestMessage request) { - var request = e.Session.HttpClient.Request; - if (request.Headers.FirstOrDefault((h) => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase)) is null) + if (!request.Headers.TryGetValues("Origin", out var _)) { return; } @@ -322,9 +320,9 @@ private void ExposeRateLimitingForCors(List headers, ProxyRe headers.Add(new("Access-Control-Expose-Headers", $"{Configuration.HeaderLimit}, {Configuration.HeaderRemaining}, {Configuration.HeaderReset}, {Configuration.HeaderRetryAfter}")); } - private static string BuildApiErrorMessage(Request r) => $"Some error was generated by the proxy. {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : string.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage()) : "")}"; + private static string BuildApiErrorMessage(HttpRequestMessage r) => $"Some error was generated by the proxy. {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : string.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage()) : "")}"; - private static string BuildThrottleKey(Request r) + private static string BuildThrottleKey(HttpRequestMessage r) { if (ProxyUtils.IsGraphRequest(r)) { @@ -332,7 +330,7 @@ private static string BuildThrottleKey(Request r) } else { - return r.RequestUri.Host; + return r.RequestUri?.Host ?? string.Empty; } } } diff --git a/DevProxy.Plugins/Behavior/RetryAfterPlugin.cs b/DevProxy.Plugins/Behavior/RetryAfterPlugin.cs index 25e58efc..52f146c9 100644 --- a/DevProxy.Plugins/Behavior/RetryAfterPlugin.cs +++ b/DevProxy.Plugins/Behavior/RetryAfterPlugin.cs @@ -11,62 +11,60 @@ using Microsoft.Extensions.Logging; using System.Globalization; using System.Net; +using System.Text; using System.Text.Json; using System.Text.RegularExpressions; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; namespace DevProxy.Plugins.Behavior; public sealed class RetryAfterPlugin( ILogger logger, - ISet urlsToWatch) : BasePlugin(logger, urlsToWatch) + ISet urlsToWatch, + IProxyStorage proxyStorage) : BasePlugin(logger, urlsToWatch) { + private readonly IProxyStorage _proxyStorage = proxyStorage; public static readonly string ThrottledRequestsKey = "ThrottledRequests"; public override string Name => nameof(RetryAfterPlugin); - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); - ArgumentNullException.ThrowIfNull(e); - - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); + return Task.FromResult(PluginResponse.Continue()); } - if (e.ResponseState.HasBeenSet) + + if (args.Request.Method == HttpMethod.Options) { - Logger.LogRequest("Response already set", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, args.Request); + return Task.FromResult(PluginResponse.Continue()); } - if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + + var throttleResponse = CheckIfThrottled(args.Request); + if (throttleResponse != null) { - Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + return Task.FromResult(PluginResponse.Respond(throttleResponse)); } - ThrottleIfNecessary(e); - - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); - return Task.CompletedTask; - } + Logger.LogTrace("Left {Name}", nameof(OnRequestAsync)); + return Task.FromResult(PluginResponse.Continue()); + }; - private void ThrottleIfNecessary(ProxyRequestArgs e) + private HttpResponseMessage? CheckIfThrottled(HttpRequestMessage request) { - var request = e.Session.HttpClient.Request; - if (!e.GlobalData.TryGetValue(ThrottledRequestsKey, out var value)) + if (!_proxyStorage.GlobalData.TryGetValue(ThrottledRequestsKey, out var value)) { - Logger.LogRequest("Request not throttled", MessageType.Skipped, new(e.Session)); - return; + Logger.LogRequest("Request not throttled", MessageType.Skipped, request); + return null; } if (value is not List throttledRequests) { - Logger.LogRequest("Request not throttled", MessageType.Skipped, new(e.Session)); - return; + Logger.LogRequest("Request not throttled", MessageType.Skipped, request); + return null; } var expiredThrottlers = throttledRequests.Where(t => t.ResetTime < DateTime.Now).ToArray(); @@ -77,8 +75,8 @@ private void ThrottleIfNecessary(ProxyRequestArgs e) if (throttledRequests.Count == 0) { - Logger.LogRequest("Request not throttled", MessageType.Skipped, new(e.Session)); - return; + Logger.LogRequest("Request not throttled", MessageType.Skipped, request); + return null; } foreach (var throttler in throttledRequests) @@ -86,23 +84,22 @@ private void ThrottleIfNecessary(ProxyRequestArgs e) var throttleInfo = throttler.ShouldThrottle(request, throttler.ThrottlingKey); if (throttleInfo.ThrottleForSeconds > 0) { - var message = $"Calling {request.Url} before waiting for the Retry-After period. Request will be throttled. Throttling on {throttler.ThrottlingKey}."; - Logger.LogRequest(message, MessageType.Failed, new(e.Session)); + var message = $"Calling {request.RequestUri} before waiting for the Retry-After period. Request will be throttled. Throttling on {throttler.ThrottlingKey}."; + Logger.LogRequest(message, MessageType.Failed, request); throttler.ResetTime = DateTime.Now.AddSeconds(throttleInfo.ThrottleForSeconds); - UpdateProxyResponse(e, throttleInfo, string.Join(' ', message)); - return; + return BuildThrottleResponse(request, throttleInfo, string.Join(' ', message)); } } - Logger.LogRequest("Request not throttled", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Request not throttled", MessageType.Skipped, request); + return null; } - private static void UpdateProxyResponse(ProxyRequestArgs e, ThrottlingInfo throttlingInfo, string message) + private static HttpResponseMessage BuildThrottleResponse(HttpRequestMessage request, ThrottlingInfo throttlingInfo, string message) { var headers = new List(); var body = string.Empty; - var request = e.Session.HttpClient.Request; // override the response body and headers for the error response if (ProxyUtils.IsGraphRequest(request)) @@ -128,7 +125,7 @@ private static void UpdateProxyResponse(ProxyRequestArgs e, ThrottlingInfo throt else { // ProxyUtils.BuildGraphResponseHeaders already includes CORS headers - if (request.Headers.Any(h => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase))) + if (request.Headers.TryGetValues("Origin", out var _)) { headers.Add(new("Access-Control-Allow-Origin", "*")); headers.Add(new("Access-Control-Expose-Headers", throttlingInfo.RetryAfterHeaderName)); @@ -137,9 +134,18 @@ private static void UpdateProxyResponse(ProxyRequestArgs e, ThrottlingInfo throt headers.Add(new(throttlingInfo.RetryAfterHeaderName, throttlingInfo.ThrottleForSeconds.ToString(CultureInfo.InvariantCulture))); - e.Session.GenericResponse(body ?? string.Empty, HttpStatusCode.TooManyRequests, headers.Select(h => new HttpHeader(h.Name, h.Value))); - e.ResponseState.HasBeenSet = true; + var response = new HttpResponseMessage(HttpStatusCode.TooManyRequests) + { + Content = new StringContent(body ?? string.Empty, Encoding.UTF8, "application/json") + }; + + foreach (var header in headers) + { + _ = response.Headers.TryAddWithoutValidation(header.Name, header.Value); + } + + return response; } - private static string BuildApiErrorMessage(Request r, string message) => $"{message} {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : string.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage()) : "")}"; + private static string BuildApiErrorMessage(HttpRequestMessage r, string message) => $"{message} {(ProxyUtils.IsGraphRequest(r) ? ProxyUtils.IsSdkRequest(r) ? "" : string.Join(' ', MessageUtils.BuildUseSdkForErrorsMessage()) : "")}"; } diff --git a/DevProxy.Plugins/DevProxy.Plugins.csproj b/DevProxy.Plugins/DevProxy.Plugins.csproj index 165b095a..26c4c776 100644 --- a/DevProxy.Plugins/DevProxy.Plugins.csproj +++ b/DevProxy.Plugins/DevProxy.Plugins.csproj @@ -60,10 +60,6 @@ false runtime - - false - runtime - diff --git a/DevProxy.Plugins/Extensions/ApiCenterExtensions.cs b/DevProxy.Plugins/Extensions/ApiCenterExtensions.cs index a387970d..39af41f3 100644 --- a/DevProxy.Plugins/Extensions/ApiCenterExtensions.cs +++ b/DevProxy.Plugins/Extensions/ApiCenterExtensions.cs @@ -189,17 +189,19 @@ internal static IEnumerable GetUrls(this Api api) return apiVersion; } - // check headers - Debug.Assert(request.Context is not null); - var header = request.Context.Session.HttpClient.Request.Headers.FirstOrDefault( - h => - (!string.IsNullOrEmpty(apiVersion.Name) && h.Value.Contains(apiVersion.Name, StringComparison.OrdinalIgnoreCase)) || - (!string.IsNullOrEmpty(apiVersion.Properties?.Title) && h.Value.Contains(apiVersion.Properties.Title, StringComparison.OrdinalIgnoreCase)) - ); - if (header is not null) + // check headers - use the new Request property instead of Context.Session + if (request.Request?.Headers != null) { - logger.LogDebug("Version {Version} found in header {Header}", $"{apiVersion.Name}/{apiVersion.Properties?.Title}", header.Name); - return apiVersion; + var header = request.Request.Headers.FirstOrDefault( + h => + (!string.IsNullOrEmpty(apiVersion.Name) && string.Join(", ", h.Value).Contains(apiVersion.Name, StringComparison.OrdinalIgnoreCase)) || + (!string.IsNullOrEmpty(apiVersion.Properties?.Title) && string.Join(", ", h.Value).Contains(apiVersion.Properties.Title, StringComparison.OrdinalIgnoreCase)) + ); + if (header.Key is not null) + { + logger.LogDebug("Version {Version} found in header {Header}", $"{apiVersion.Name}/{apiVersion.Properties?.Title}", header.Key); + return apiVersion; + } } } diff --git a/DevProxy.Plugins/Extensions/OpenApiDocumentExtensions.cs b/DevProxy.Plugins/Extensions/OpenApiDocumentExtensions.cs index 9458db4d..a2e9bf6e 100644 --- a/DevProxy.Plugins/Extensions/OpenApiDocumentExtensions.cs +++ b/DevProxy.Plugins/Extensions/OpenApiDocumentExtensions.cs @@ -32,7 +32,7 @@ public static ApiPermissionsInfo CheckMinimalPermissions(this OpenApiDocument op logger.LogDebug("Checking request {Request}...", methodAndUrl); var (method, url) = (methodAndUrlChunks[0].ToUpperInvariant(), methodAndUrlChunks[1]); - var scopesFromTheToken = MinimalPermissionsUtils.GetScopesFromToken(request.Context?.Session.HttpClient.Request.Headers.First(h => h.Name.Equals("authorization", StringComparison.OrdinalIgnoreCase)).Value, logger); + var scopesFromTheToken = MinimalPermissionsUtils.GetScopesFromToken(request.Request!.Headers.Authorization?.Parameter, logger); if (scopesFromTheToken.Length != 0) { tokenPermissions.AddRange(scopesFromTheToken); diff --git a/DevProxy.Plugins/Generation/ApiCenterOnboardingPlugin.cs b/DevProxy.Plugins/Generation/ApiCenterOnboardingPlugin.cs index 090cb88f..aa3597af 100644 --- a/DevProxy.Plugins/Generation/ApiCenterOnboardingPlugin.cs +++ b/DevProxy.Plugins/Generation/ApiCenterOnboardingPlugin.cs @@ -29,13 +29,15 @@ public sealed class ApiCenterOnboardingPlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BaseReportingPlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, + proxyStorage) { private ApiCenterClient? _apiCenterClient; private Api[]? _apis; @@ -82,7 +84,9 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell Logger.LogDebug("Plugin {Plugin} auth confirmed...", Name); } - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -181,7 +185,7 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation { ExistingApis = [.. existingApis], NewApis = [] - }, e); + }); return; } @@ -196,7 +200,7 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation Method = a.method, Url = a.url })] - }, e); + }); var apisPerSchemeAndHost = newApis.GroupBy(x => { @@ -218,7 +222,7 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation return; } - var generatedOpenApiSpecs = e.GlobalData.TryGetValue(OpenApiSpecGeneratorPlugin.GeneratedOpenApiSpecsKey, out var specs) ? specs as Dictionary : []; + var generatedOpenApiSpecs = ProxyStorage.GlobalData.TryGetValue(OpenApiSpecGeneratorPlugin.GeneratedOpenApiSpecsKey, out var specs) ? specs as Dictionary : []; await CreateApisInApiCenterAsync(apisPerSchemeAndHost, generatedOpenApiSpecs!, cancellationToken); Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); diff --git a/DevProxy.Plugins/Generation/HttpFileGeneratorPlugin.cs b/DevProxy.Plugins/Generation/HttpFileGeneratorPlugin.cs index f37decd8..b8354545 100644 --- a/DevProxy.Plugins/Generation/HttpFileGeneratorPlugin.cs +++ b/DevProxy.Plugins/Generation/HttpFileGeneratorPlugin.cs @@ -83,13 +83,15 @@ public sealed class HttpFileGeneratorPlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BaseReportingPlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, + proxyStorage) { public static readonly string GeneratedHttpFilesKey = "GeneratedHttpFiles"; @@ -98,9 +100,10 @@ public sealed class HttpFileGeneratorPlugin( public override string Name => nameof(HttpFileGeneratorPlugin); - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); ArgumentNullException.ThrowIfNull(e); @@ -122,11 +125,11 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation Logger.LogInformation("Created HTTP file {FileName}", fileName); var generatedHttpFiles = new[] { fileName }; - StoreReport(new HttpFileGeneratorPluginReport(generatedHttpFiles), e); + StoreReport(new HttpFileGeneratorPluginReport(generatedHttpFiles)); // store the generated HTTP files in the global data // for use by other plugins - e.GlobalData[GeneratedHttpFilesKey] = generatedHttpFiles; + ProxyStorage.GlobalData[GeneratedHttpFilesKey] = generatedHttpFiles; Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); } @@ -135,22 +138,22 @@ private async Task GetHttpRequestsAsync(IEnumerable reques { var httpFile = new HttpFile(); - foreach (var request in requestLogs) + foreach (var request in requestLogs.Where(l => + l.MessageType == MessageType.InterceptedResponse + && l.Request is not null + && l.Response is not null)) { cancellationToken.ThrowIfCancellationRequested(); - if (request.MessageType != MessageType.InterceptedResponse || - request.Context is null || - request.Context.Session is null || - !ProxyUtils.MatchesUrlToWatch(UrlsToWatch, request.Context.Session.HttpClient.Request.RequestUri.AbsoluteUri)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, request.Request!.RequestUri!.AbsoluteUri)) { continue; } if (!Configuration.IncludeOptionsRequests && - string.Equals(request.Context.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + request.Request.Method == HttpMethod.Options) { - Logger.LogDebug("Skipping OPTIONS request {Url}...", request.Context.Session.HttpClient.Request.RequestUri); + Logger.LogDebug("Skipping OPTIONS request {Url}...", request.Request.RequestUri); continue; } @@ -162,8 +165,8 @@ request.Context.Session is null || { Method = methodAndUrl[0], Url = methodAndUrl[1], - Body = request.Context.Session.HttpClient.Request.HasBody ? await request.Context.Session.GetRequestBodyAsString(cancellationToken) : null, - Headers = [.. request.Context.Session.HttpClient.Request.Headers.Select(h => new HttpFileRequestHeader { Name = h.Name, Value = h.Value })] + Body = request.Request.Content is not null ? await request.Request.Content.ReadAsStringAsync(cancellationToken) : null, + Headers = [.. request.Request.Headers.Select(h => new HttpFileRequestHeader { Name = h.Key, Value = string.Join(',', h.Value) })] }); } diff --git a/DevProxy.Plugins/Generation/MockGeneratorPlugin.cs b/DevProxy.Plugins/Generation/MockGeneratorPlugin.cs index 957e6564..eb6b1ab3 100644 --- a/DevProxy.Plugins/Generation/MockGeneratorPlugin.cs +++ b/DevProxy.Plugins/Generation/MockGeneratorPlugin.cs @@ -9,18 +9,20 @@ using DevProxy.Plugins.Mocking; using DevProxy.Plugins.Utils; using Microsoft.Extensions.Logging; +using System.Net.Http.Json; using System.Text.Json; -using Titanium.Web.Proxy.EventArguments; namespace DevProxy.Plugins.Generation; public sealed class MockGeneratorPlugin( ILogger logger, - ISet urlsToWatch) : BaseReportingPlugin(logger, urlsToWatch) + ISet urlsToWatch, + IProxyStorage proxyStorage) : BaseReportingPlugin(logger, urlsToWatch, proxyStorage) { public override string Name => nameof(MockGeneratorPlugin); - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -42,9 +44,7 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation cancellationToken.ThrowIfCancellationRequested(); if (request.MessageType != MessageType.InterceptedResponse || - request.Context is null || - request.Context.Session is null || - !ProxyUtils.MatchesUrlToWatch(UrlsToWatch, request.Context.Session.HttpClient.Request.RequestUri.AbsoluteUri)) + !ProxyUtils.MatchesUrlToWatch(UrlsToWatch, request.Request!.RequestUri!.AbsoluteUri)) { continue; } @@ -53,10 +53,16 @@ request.Context.Session is null || Logger.LogDebug("Processing request {MethodAndUrlString}...", methodAndUrlString); var (method, url) = GetMethodAndUrl(methodAndUrlString); - var response = request.Context.Session.HttpClient.Response; + var response = request.Response; + + if (response is null) + { + Logger.LogDebug("No response found for request {MethodAndUrlString}. Skipping", methodAndUrlString); + continue; + } var newHeaders = new List(); - newHeaders.AddRange(response.Headers.Select(h => new MockResponseHeader(h.Name, h.Value))); + newHeaders.AddRange(response.Headers.Select(h => new MockResponseHeader(h.Key, string.Join(';', h.Value)))); var mock = new MockResponse { Request = new() @@ -66,9 +72,9 @@ request.Context.Session is null || }, Response = new() { - StatusCode = response.StatusCode, + StatusCode = (int)response.StatusCode, Headers = newHeaders, - Body = await GetResponseBodyAsync(request.Context.Session, cancellationToken) + Body = await GetResponseBodyAsync(request.Response!, cancellationToken) } }; // skip mock if it's 200 but has no body @@ -97,7 +103,7 @@ request.Context.Session is null || Logger.LogInformation("Created mock file {FileName} with {MocksCount} mocks", fileName, mocks.Count); - StoreReport(fileName, e); + StoreReport(fileName); Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); } @@ -108,28 +114,24 @@ request.Context.Session is null || /// /// Request session /// Response body or @filename for binary responses - private async Task GetResponseBodyAsync(SessionEventArgs session, CancellationToken cancellationToken) + private async Task GetResponseBodyAsync(HttpResponseMessage response, CancellationToken cancellationToken) { Logger.LogDebug("Getting response body..."); - var response = session.HttpClient.Response; - if (response.ContentType is null || !response.HasBody) + if (response.Content is null) { Logger.LogDebug("Response has no content-type set or has no body. Skipping"); return null; } - if (response.ContentType.Contains("application/json", StringComparison.OrdinalIgnoreCase)) + if (response.Content.Headers.ContentType?.MediaType?.Contains("application/json", StringComparison.OrdinalIgnoreCase) ?? false) { Logger.LogDebug("Response is JSON"); try { Logger.LogDebug("Reading response body as string..."); - var body = response.IsBodyRead ? response.BodyString : await session.GetResponseBodyAsString(cancellationToken); - Logger.LogDebug("Body: {Body}", body); - Logger.LogDebug("Deserializing response body..."); - return JsonSerializer.Deserialize(body, ProxyUtils.JsonSerializerOptions); + return await response.Content.ReadFromJsonAsync(ProxyUtils.JsonSerializerOptions, cancellationToken); } catch (Exception ex) { @@ -144,7 +146,7 @@ request.Context.Session is null || { var filename = $"response-{Guid.NewGuid()}.bin"; Logger.LogDebug("Reading response body as bytes..."); - var body = await session.GetResponseBody(cancellationToken); + var body = await response.Content.ReadAsByteArrayAsync(cancellationToken); Logger.LogDebug("Writing response body to {Filename}...", filename); await File.WriteAllBytesAsync(filename, body, cancellationToken); return $"@{filename}"; diff --git a/DevProxy.Plugins/Generation/OpenApiSpecGeneratorPlugin.cs b/DevProxy.Plugins/Generation/OpenApiSpecGeneratorPlugin.cs index 255b2a5d..878538bf 100644 --- a/DevProxy.Plugins/Generation/OpenApiSpecGeneratorPlugin.cs +++ b/DevProxy.Plugins/Generation/OpenApiSpecGeneratorPlugin.cs @@ -19,8 +19,6 @@ using System.Text.Json; using System.Text.Json.Serialization; using System.Web; -using Titanium.Web.Proxy.EventArguments; -using Titanium.Web.Proxy.Http; namespace DevProxy.Plugins.Generation; @@ -66,19 +64,22 @@ public sealed class OpenApiSpecGeneratorPlugin( ISet urlsToWatch, ILanguageModelClient languageModelClient, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BaseReportingPlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, + proxyStorage) { public static readonly string GeneratedOpenApiSpecsKey = "GeneratedOpenApiSpecs"; public override string Name => nameof(OpenApiSpecGeneratorPlugin); - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -99,17 +100,15 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation cancellationToken.ThrowIfCancellationRequested(); if (request.MessageType != MessageType.InterceptedResponse || - request.Context is null || - request.Context.Session is null || - !ProxyUtils.MatchesUrlToWatch(UrlsToWatch, request.Context.Session.HttpClient.Request.RequestUri.AbsoluteUri)) + !ProxyUtils.MatchesUrlToWatch(UrlsToWatch, request.Request!.RequestUri!.AbsoluteUri)) { continue; } if (!Configuration.IncludeOptionsRequests && - string.Equals(request.Context.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + request.Request!.Method == HttpMethod.Options) { - Logger.LogDebug("Skipping OPTIONS request {Url}...", request.Context.Session.HttpClient.Request.RequestUri); + Logger.LogDebug("Skipping OPTIONS request {Url}...", request.Request.RequestUri); continue; } @@ -118,22 +117,22 @@ request.Context.Session is null || try { - var pathItem = GetOpenApiPathItem(request.Context.Session); - var parametrizedPath = ParametrizePath(pathItem, request.Context.Session.HttpClient.Request.RequestUri); + var pathItem = await GetOpenApiPathItem(request.Request, request.Response); + var parametrizedPath = ParametrizePath(pathItem, request.Request.RequestUri!); var operationInfo = pathItem.Operations.First(); operationInfo.Value.OperationId = await GetOperationIdAsync( operationInfo.Key.ToString(), - request.Context.Session.HttpClient.Request.RequestUri.GetLeftPart(UriPartial.Authority), + request.Request.RequestUri.GetLeftPart(UriPartial.Authority), parametrizedPath, cancellationToken ); operationInfo.Value.Description = await GetOperationDescriptionAsync( operationInfo.Key.ToString(), - request.Context.Session.HttpClient.Request.RequestUri.GetLeftPart(UriPartial.Authority), + request.Request.RequestUri.GetLeftPart(UriPartial.Authority), parametrizedPath, cancellationToken ); - AddOrMergePathItem(openApiDocs, pathItem, request.Context.Session.HttpClient.Request.RequestUri, parametrizedPath); + AddOrMergePathItem(openApiDocs, pathItem, request.Request.RequestUri, parametrizedPath); } catch (Exception ex) { @@ -176,11 +175,11 @@ request.Context.Session is null || { ServerUrl = kvp.Key, FileName = kvp.Value - })), e); + }))); // store the generated OpenAPI specs in the global data // for use by other plugins - e.GlobalData[GeneratedOpenApiSpecsKey] = generatedOpenApiSpecs; + ProxyStorage.GlobalData[GeneratedOpenApiSpecsKey] = generatedOpenApiSpecs; Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); } @@ -206,17 +205,15 @@ private async Task GetOperationDescriptionAsync(string method, string se /** * Creates an OpenAPI PathItem from an intercepted request and response pair. - * @param session The intercepted session. + * @param httpRequest The HTTP request message. + * @param httpResponse The HTTP response message. */ - private OpenApiPathItem GetOpenApiPathItem(SessionEventArgs session) + private async Task GetOpenApiPathItem(HttpRequestMessage httpRequest, HttpResponseMessage? httpResponse) { - var request = session.HttpClient.Request; - var response = session.HttpClient.Response; - - var resource = GetLastNonTokenSegment(request.RequestUri.Segments); + var resource = GetLastNonTokenSegment(httpRequest.RequestUri!.Segments); var path = new OpenApiPathItem(); - var method = request.Method?.ToUpperInvariant() switch + var method = httpRequest.Method.Method.ToUpperInvariant() switch { "DELETE" => OperationType.Delete, "GET" => OperationType.Get, @@ -226,7 +223,7 @@ private OpenApiPathItem GetOpenApiPathItem(SessionEventArgs session) "POST" => OperationType.Post, "PUT" => OperationType.Put, "TRACE" => OperationType.Trace, - _ => throw new NotSupportedException($"Method {request.Method} is not supported") + _ => throw new NotSupportedException($"Method {httpRequest.Method} is not supported") }; var operation = new OpenApiOperation { @@ -235,50 +232,51 @@ private OpenApiPathItem GetOpenApiPathItem(SessionEventArgs session) // will be replaced later after the path has been parametrized OperationId = $"{method}.{resource}" }; - SetParametersFromQueryString(operation, HttpUtility.ParseQueryString(request.RequestUri.Query)); - SetParametersFromRequestHeaders(operation, request.Headers); - SetRequestBody(operation, request); - SetResponseFromSession(operation, response); + SetParametersFromQueryString(operation, HttpUtility.ParseQueryString(httpRequest.RequestUri.Query)); + SetParametersFromRequestHeaders(operation, httpRequest.Headers); + await SetRequestBody(operation, httpRequest); + await SetResponseFromHttpResponseMessage(operation, httpResponse); path.Operations.Add(method, operation); return path; } - private void SetRequestBody(OpenApiOperation operation, Request request) + private async Task SetRequestBody(OpenApiOperation operation, HttpRequestMessage httpRequest) { - if (!request.HasBody) + if (httpRequest.Content is null) { Logger.LogDebug(" Request has no body"); return; } - if (request.ContentType is null) + var contentType = httpRequest.Content.Headers.ContentType?.MediaType; + if (contentType is null) { Logger.LogDebug(" Request has no content type"); return; } Logger.LogDebug(" Processing request body..."); + var bodyString = await httpRequest.Content.ReadAsStringAsync(); operation.RequestBody = new() { Content = new Dictionary { { - GetMediaType(request.ContentType), + GetMediaType(contentType), new() { - Schema = GetSchemaFromBody(GetMediaType(request.ContentType), request.BodyString) + Schema = GetSchemaFromBody(GetMediaType(contentType), bodyString) } } } }; } - private void SetParametersFromRequestHeaders(OpenApiOperation operation, HeaderCollection headers) + private void SetParametersFromRequestHeaders(OpenApiOperation operation, System.Net.Http.Headers.HttpRequestHeaders headers) { - if (headers is null || - !headers.Any()) + if (headers is null || !headers.Any()) { Logger.LogDebug(" Request has no headers"); return; @@ -287,27 +285,27 @@ private void SetParametersFromRequestHeaders(OpenApiOperation operation, HeaderC Logger.LogDebug(" Processing request headers..."); foreach (var header in headers) { - var lowerCaseHeaderName = header.Name.ToLowerInvariant(); + var lowerCaseHeaderName = header.Key.ToLowerInvariant(); if (Http.StandardHeaders.Contains(lowerCaseHeaderName)) { - Logger.LogDebug(" Skipping standard header {HeaderName}", header.Name); + Logger.LogDebug(" Skipping standard header {HeaderName}", header.Key); continue; } if (Http.AuthHeaders.Contains(lowerCaseHeaderName)) { - Logger.LogDebug(" Skipping auth header {HeaderName}", header.Name); + Logger.LogDebug(" Skipping auth header {HeaderName}", header.Key); continue; } operation.Parameters.Add(new() { - Name = header.Name, + Name = header.Key, In = ParameterLocation.Header, Required = false, Schema = new() { Type = "string" } }); - Logger.LogDebug(" Added header {HeaderName}", header.Name); + Logger.LogDebug(" Added header {HeaderName}", header.Key); } } @@ -352,9 +350,9 @@ private static void SetParameterDefault(OpenApiParameter parameter, object? valu parameter.Schema.Default = new OpenApiString(value.ToString()); } - private void SetResponseFromSession(OpenApiOperation operation, Response response) + private async Task SetResponseFromHttpResponseMessage(OpenApiOperation operation, HttpResponseMessage? httpResponse) { - if (response is null) + if (httpResponse is null) { Logger.LogDebug(" No response to process"); return; @@ -364,16 +362,20 @@ private void SetResponseFromSession(OpenApiOperation operation, Response respons var openApiResponse = new OpenApiResponse { - Description = response.StatusDescription + Description = httpResponse.ReasonPhrase ?? httpResponse.StatusCode.ToString() }; - var responseCode = response.StatusCode.ToString(CultureInfo.InvariantCulture); - if (response.HasBody) + var responseCode = ((int)httpResponse.StatusCode).ToString(CultureInfo.InvariantCulture); + + if (httpResponse.Content is not null) { Logger.LogDebug(" Response has body"); - openApiResponse.Content.Add(GetMediaType(response.ContentType), new() + var contentType = httpResponse.Content.Headers.ContentType?.MediaType; + var responseBody = await httpResponse.Content.ReadAsStringAsync(); + + openApiResponse.Content.Add(GetMediaType(contentType), new() { - Schema = GetSchemaFromBody(GetMediaType(response.ContentType), response.BodyString) + Schema = GetSchemaFromBody(GetMediaType(contentType), responseBody) }); } else @@ -381,36 +383,36 @@ private void SetResponseFromSession(OpenApiOperation operation, Response respons Logger.LogDebug(" Response doesn't have body"); } - if (response.Headers is not null && response.Headers.Any()) + if (httpResponse.Headers is not null && httpResponse.Headers.Any()) { Logger.LogDebug(" Response has headers"); - foreach (var header in response.Headers) + foreach (var header in httpResponse.Headers) { - var lowerCaseHeaderName = header.Name.ToLowerInvariant(); + var lowerCaseHeaderName = header.Key.ToLowerInvariant(); if (Http.StandardHeaders.Contains(lowerCaseHeaderName)) { - Logger.LogDebug(" Skipping standard header {HeaderName}", header.Name); + Logger.LogDebug(" Skipping standard header {HeaderName}", header.Key); continue; } if (Http.AuthHeaders.Contains(lowerCaseHeaderName)) { - Logger.LogDebug(" Skipping auth header {HeaderName}", header.Name); + Logger.LogDebug(" Skipping auth header {HeaderName}", header.Key); continue; } - if (openApiResponse.Headers.ContainsKey(header.Name)) + if (openApiResponse.Headers.ContainsKey(header.Key)) { - Logger.LogDebug(" Header {HeaderName} already exists in response", header.Name); + Logger.LogDebug(" Header {HeaderName} already exists in response", header.Key); continue; } - openApiResponse.Headers.Add(header.Name, new() + openApiResponse.Headers.Add(header.Key, new() { Schema = new() { Type = "string" } }); - Logger.LogDebug(" Added header {HeaderName}", header.Name); + Logger.LogDebug(" Added header {HeaderName}", header.Key); } } else @@ -469,6 +471,7 @@ private void AddOrMergePathItem(List openApiDocs, OpenApiPathIt new() { Url = serverUrl } ], Paths = [], + Extensions = new Dictionary { { "x-ms-generated-by", new GeneratedByOpenApiExtension() } @@ -670,6 +673,7 @@ private static string ParametrizePath(OpenApiPathItem pathItem, Uri requestUri) for (var i = 0; i < segments.Length; i++) { var segment = requestUri.Segments[i].Trim('/'); + if (string.IsNullOrEmpty(segment)) { continue; diff --git a/DevProxy.Plugins/Generation/TypeSpecGeneratorPlugin.cs b/DevProxy.Plugins/Generation/TypeSpecGeneratorPlugin.cs index b95cee82..e7929c53 100644 --- a/DevProxy.Plugins/Generation/TypeSpecGeneratorPlugin.cs +++ b/DevProxy.Plugins/Generation/TypeSpecGeneratorPlugin.cs @@ -14,7 +14,6 @@ using System.Text.Json; using System.Text.RegularExpressions; using System.Web; -using Titanium.Web.Proxy.Http; namespace DevProxy.Plugins.Generation; @@ -42,19 +41,22 @@ public sealed class TypeSpecGeneratorPlugin( ISet urlsToWatch, ILanguageModelClient languageModelClient, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BaseReportingPlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, + proxyStorage) { public static readonly string GeneratedTypeSpecFilesKey = "GeneratedTypeSpecFiles"; public override string Name => nameof(TypeSpecGeneratorPlugin); - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -70,18 +72,18 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation var typeSpecFiles = new List(); - foreach (var request in e.RequestLogs) + foreach (var request in e.RequestLogs.Where(l => + l.MessageType == MessageType.InterceptedRequest + && l.Request is not null + && l.Response is not null + && l.Request.Method != HttpMethod.Options)) { cancellationToken.ThrowIfCancellationRequested(); - if (request.MessageType != MessageType.InterceptedResponse || - request.Context is null || - request.Context.Session is null || + if ( request.Url is null || request.Method is null || - // TypeSpec does not support OPTIONS requests - string.Equals(request.Context.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase) || - !ProxyUtils.MatchesUrlToWatch(UrlsToWatch, request.Context.Session.HttpClient.Request.RequestUri.AbsoluteUri)) + !ProxyUtils.MatchesUrlToWatch(UrlsToWatch, request.Request!.RequestUri!.AbsoluteUri)) { continue; } @@ -127,11 +129,11 @@ request.Method is null || { ServerUrl = kvp.Key, FileName = kvp.Value - })), e); + }))); // store the generated TypeSpec files in the global data // for use by other plugins - e.GlobalData[GeneratedTypeSpecFilesKey] = generatedTypeSpecFiles; + ProxyStorage.GlobalData[GeneratedTypeSpecFilesKey] = generatedTypeSpecFiles; Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); } @@ -140,18 +142,19 @@ private async Task GetOperationAsync(RequestLog request, TypeSpecFile { Logger.LogTrace("Entered {Name}", nameof(GetOperationAsync)); - Debug.Assert(request.Context is not null, "request.Context is null"); + Debug.Assert(request.Request is not null, "request.Request is null"); + Debug.Assert(request.Response is not null, "request.Response is null"); Debug.Assert(request.Method is not null, "request.Method is null"); Debug.Assert(request.Url is not null, "request.Url is null"); var url = new Uri(request.Url); - var httpRequest = request.Context.Session.HttpClient.Request; - var httpResponse = request.Context.Session.HttpClient.Response; + var httpRequest = request.Request; + var httpResponse = request.Response; - var (route, parameters) = await GetRouteAndParametersAsync(url); + var (route, parameters) = GetRouteAndParametersAsync(url); var op = new Operation { - Name = await GetOperationNameAsync(request.Method, url), + Name = GetOperationName(request.Method, url), Description = await GetOperationDescriptionAsync(request.Method, url), Method = Enum.Parse(request.Method, true), Route = route, @@ -170,13 +173,13 @@ private async Task GetOperationAsync(RequestLog request, TypeSpecFile return op; } - private void ProcessAuth(Request httpRequest, TypeSpecFile doc, Operation op) + private void ProcessAuth(HttpRequestMessage httpRequest, TypeSpecFile doc, Operation op) { Logger.LogTrace("Entered {Name}", nameof(ProcessAuth)); var authHeaders = httpRequest.Headers - .Where(h => Http.AuthHeaders.Contains(h.Name.ToLowerInvariant())) - .Select(h => (h.Name, h.Value)); + .Where(h => Http.AuthHeaders.Contains(h.Key.ToLowerInvariant())) + .Select(h => (h.Key, string.Join(", ", h.Value))); foreach (var (name, value) in authHeaders) { @@ -197,7 +200,7 @@ private void ProcessAuth(Request httpRequest, TypeSpecFile doc, Operation op) return; } - var query = HttpUtility.ParseQueryString(httpRequest.RequestUri.Query); + var query = HttpUtility.ParseQueryString(httpRequest.RequestUri?.Query ?? string.Empty); var authQueryParam = query.AllKeys .FirstOrDefault(k => k is not null && Http.AuthHeaders.Contains(k.ToLowerInvariant())); if (authQueryParam is not null) @@ -321,17 +324,18 @@ private bool IsJwtToken(string bearerToken, out JwtSecurityToken jwtToken) return false; } - private async Task ProcessRequestBodyAsync(Request httpRequest, TypeSpecFile doc, Operation op, string lastSegment) + private async Task ProcessRequestBodyAsync(HttpRequestMessage httpRequest, TypeSpecFile doc, Operation op, string lastSegment) { Logger.LogTrace("Entered {Name}", nameof(ProcessRequestBodyAsync)); - if (!httpRequest.HasBody) + if (httpRequest.Content is null) { Logger.LogDebug("Request has no body, skipping..."); return; } - var models = await GetModelsFromStringAsync(httpRequest.BodyString, lastSegment.ToPascalCase()); + var requestBody = await httpRequest.Content.ReadAsStringAsync(); + var models = await GetModelsFromStringAsync(requestBody, lastSegment.ToPascalCase()); if (models.Length > 0) { foreach (var model in models) @@ -342,7 +346,7 @@ private async Task ProcessRequestBodyAsync(Request httpRequest, TypeSpecFile doc var rootModel = models.Last(); op.Parameters.Add(new() { - Name = await GetParameterNameAsync(rootModel), + Name = GetParameterNameAsync(rootModel), Value = rootModel.Name, In = ParameterLocation.Body }); @@ -351,22 +355,22 @@ private async Task ProcessRequestBodyAsync(Request httpRequest, TypeSpecFile doc Logger.LogTrace("Left {Name}", nameof(ProcessRequestBodyAsync)); } - private void ProcessRequestHeaders(Request httpRequest, Operation op) + private void ProcessRequestHeaders(HttpRequestMessage httpRequest, Operation op) { Logger.LogTrace("Entered {Name}", nameof(ProcessRequestHeaders)); foreach (var header in httpRequest.Headers) { - if (Http.StandardHeaders.Contains(header.Name.ToLowerInvariant()) || - Http.AuthHeaders.Contains(header.Name.ToLowerInvariant())) + if (Http.StandardHeaders.Contains(header.Key.ToLowerInvariant()) || + Http.AuthHeaders.Contains(header.Key.ToLowerInvariant())) { continue; } op.Parameters.Add(new() { - Name = header.Name, - Value = GetValueType(header.Value), + Name = header.Key, + Value = GetValueType(string.Join(", ", header.Value)), In = ParameterLocation.Header }); } @@ -374,7 +378,7 @@ private void ProcessRequestHeaders(Request httpRequest, Operation op) Logger.LogTrace("Left {Name}", nameof(ProcessRequestHeaders)); } - private async Task ProcessResponseAsync(Response? httpResponse, TypeSpecFile doc, Operation op, string lastSegment, Uri url) + private async Task ProcessResponseAsync(HttpResponseMessage? httpResponse, TypeSpecFile doc, Operation op, string lastSegment, Uri url) { Logger.LogTrace("Entered {Name}", nameof(ProcessResponseAsync)); @@ -390,7 +394,7 @@ private async Task ProcessResponseAsync(Response? httpResponse, TypeSpecFile doc { res = new() { - StatusCode = httpResponse.StatusCode, + StatusCode = (int)httpResponse.StatusCode, BodyType = "string" }; } @@ -398,16 +402,17 @@ private async Task ProcessResponseAsync(Response? httpResponse, TypeSpecFile doc { res = new() { - StatusCode = httpResponse.StatusCode, + StatusCode = (int)httpResponse.StatusCode, Headers = httpResponse.Headers - .Where(h => !Http.StandardHeaders.Contains(h.Name.ToLowerInvariant()) && - !Http.AuthHeaders.Contains(h.Name.ToLowerInvariant())) - .ToDictionary(h => h.Name.ToCamelCase(), h => h.Value.GetType().Name) + .Where(h => !Http.StandardHeaders.Contains(h.Key.ToLowerInvariant()) && + !Http.AuthHeaders.Contains(h.Key.ToLowerInvariant())) + .ToDictionary(h => h.Key.ToCamelCase(), h => string.Join(", ", h.Value).GetType().Name) }; - if (httpResponse.HasBody) + if (httpResponse.Content is not null) { - var models = await GetModelsFromStringAsync(httpResponse.BodyString, lastSegment.ToPascalCase(), httpResponse.StatusCode >= 400); + var responseBody = await httpResponse.Content.ReadAsStringAsync(); + var models = await GetModelsFromStringAsync(responseBody, lastSegment.ToPascalCase(), (int)httpResponse.StatusCode >= 400); if (models.Length > 0) { foreach (var model in models) @@ -419,7 +424,7 @@ private async Task ProcessResponseAsync(Response? httpResponse, TypeSpecFile doc if (rootModel.IsArray) { res.BodyType = $"{rootModel.Name}[]"; - op.Name = await GetOperationNameAsync("list", url); + op.Name = GetOperationName("list", url); } else { @@ -438,11 +443,11 @@ private async Task ProcessResponseAsync(Response? httpResponse, TypeSpecFile doc Logger.LogTrace("Left {Name}", nameof(ProcessResponseAsync)); } - private async Task GetParameterNameAsync(Model model) + private string GetParameterNameAsync(Model model) { Logger.LogTrace("Entered {Name}", nameof(GetParameterNameAsync)); - var name = model.IsArray ? SanitizeName(await MakeSingularAsync(model.Name)) : model.Name; + var name = model.IsArray ? SanitizeName(MakeSingularAsync(model.Name)) : model.Name; if (string.IsNullOrEmpty(name)) { name = model.Name; @@ -506,15 +511,15 @@ private string GetRootNamespaceName(Uri url) return ns; } - private async Task GetOperationNameAsync(string method, Uri url) + private string GetOperationName(string method, Uri url) { - Logger.LogTrace("Entered {Name}", nameof(GetOperationNameAsync)); + Logger.LogTrace("Entered {Name}", nameof(GetOperationName)); var lastSegment = GetLastNonParametrizableSegment(url); Logger.LogDebug("Url: {Url}", url); Logger.LogDebug("Last non-parametrizable segment: {LastSegment}", lastSegment); - var name = method == "list" ? lastSegment : await MakeSingularAsync(lastSegment); + var name = method == "list" ? lastSegment : MakeSingularAsync(lastSegment); if (string.IsNullOrEmpty(name)) { name = lastSegment; @@ -538,7 +543,7 @@ private async Task GetOperationNameAsync(string method, Uri url) } Logger.LogDebug("Operation name: {OperationName}", operationName); - Logger.LogTrace("Left {Name}", nameof(GetOperationNameAsync)); + Logger.LogTrace("Left {Name}", nameof(GetOperationName)); return operationName; } @@ -608,113 +613,42 @@ private string GetLastNonParametrizableSegment(Uri url) for (var i = segments.Length - 1; i >= 0; i--) { var segment = segments[i].Trim('/'); - if (!IsParametrizable(segment)) + Logger.LogDebug("Segment: {Segment}", segment); + if (string.IsNullOrEmpty(segment) || segment.StartsWith('{')) { - Logger.LogDebug("Last non-parametrizable segment: {Segment}", segment); - Logger.LogTrace("Left {Name}", nameof(GetLastNonParametrizableSegment)); - - return segment; + continue; } + + Logger.LogTrace("Left {Name}", nameof(GetLastNonParametrizableSegment)); + return segment; } - Logger.LogDebug("No non-parametrizable segment found, returning empty string"); Logger.LogTrace("Left {Name}", nameof(GetLastNonParametrizableSegment)); - return string.Empty; } - private bool IsParametrizable(string segment) - { - Logger.LogTrace("Entered {Name}", nameof(IsParametrizable)); - - var isParametrizable = Guid.TryParse(segment, out _) || - int.TryParse(segment, out _); - - Logger.LogDebug("Is segment '{Segment}' parametrizable? {IsParametrizable}", segment, isParametrizable); - Logger.LogTrace("Left {Name}", nameof(IsParametrizable)); - - return isParametrizable; - } - - private async Task<(string Route, Parameter[] Parameters)> GetRouteAndParametersAsync(Uri url) + private string SanitizeName(string name) { - Logger.LogTrace("Entered {Name}", nameof(GetRouteAndParametersAsync)); - - var route = new List(); - var parameters = new List(); - var previousSegment = "item"; + Logger.LogTrace("Entered {Name}", nameof(SanitizeName)); - foreach (var segment in url.Segments) + if (string.IsNullOrEmpty(name)) { - Logger.LogDebug("Processing segment: {Segment}", segment); - - var segmentTrimmed = segment.Trim('/'); - if (string.IsNullOrEmpty(segmentTrimmed)) - { - continue; - } - - if (IsParametrizable(segmentTrimmed)) - { - var paramName = $"{previousSegment}Id"; - parameters.Add(new() - { - Name = paramName, - Value = GetValueType(segmentTrimmed), - In = ParameterLocation.Path - }); - route.Add($"{{{paramName}}}"); - } - else - { - previousSegment = SanitizeName(await MakeSingularAsync(segmentTrimmed)); - if (string.IsNullOrEmpty(previousSegment)) - { - previousSegment = SanitizeName(segmentTrimmed); - if (previousSegment.Length == 0) - { - previousSegment = GetRandomName(); - } - } - previousSegment = previousSegment.ToCamelCase(); - route.Add(segmentTrimmed); - } + Logger.LogTrace("Left {Name}", nameof(SanitizeName)); + return string.Empty; } - if (url.Query.Length > 0) - { - Logger.LogDebug("Processing query string: {Query}", url.Query); - - var query = HttpUtility.ParseQueryString(url.Query); - foreach (string key in query.Keys) - { - if (Http.AuthHeaders.Contains(key.ToLowerInvariant())) - { - Logger.LogDebug("Skipping auth header: {Key}", key); - continue; - } - - parameters.Add(new() - { - Name = key.ToCamelFromKebabCase(), - Value = GetValueType(query[key]), - In = ParameterLocation.Query - }); - } - } - else - { - Logger.LogDebug("No query string found in URL: {Url}", url); - } + // remove invalid characters + name = Regex.Replace(name, @"[^a-zA-Z0-9_]", "_", RegexOptions.Compiled); + Logger.LogDebug("Sanitized name: {Name}", name); - Logger.LogTrace("Left {Name}", nameof(GetRouteAndParametersAsync)); + Logger.LogTrace("Left {Name}", nameof(SanitizeName)); - return (string.Join('/', route), parameters.ToArray()); + return name; } private async Task GetModelsFromStringAsync(string? str, string name, bool isError = false) { - Logger.LogTrace("Entered {Same}", nameof(GetModelsFromStringAsync)); + Logger.LogTrace("Entered {Name}", nameof(GetModelsFromStringAsync)); if (string.IsNullOrEmpty(str)) { @@ -733,7 +667,7 @@ private async Task GetModelsFromStringAsync(string? str, string name, b } catch (Exception ex) { - Logger.LogDebug("Failed to parse JSON string, returning empty model list. Exception: {Ex}", ex.Message); + Logger.LogDebug("Failed to parse JSON token: {Ex}", ex.Message); // If the string is not a valid JSON, we return an empty model list Logger.LogTrace("Left {Name}", nameof(GetModelsFromStringAsync)); @@ -749,9 +683,7 @@ private async Task AddModelFromJsonElementAsync(JsonElement jsonElement, { Logger.LogTrace("Entered {Name}", nameof(AddModelFromJsonElementAsync)); -#pragma warning disable IDE0010 switch (jsonElement.ValueKind) -#pragma warning restore IDE0010 { case JsonValueKind.String: return "string"; @@ -773,7 +705,7 @@ private async Task AddModelFromJsonElementAsync(JsonElement jsonElement, var model = new Model { - Name = await GetModelNameAsync(name), + Name = GetModelName(name), IsError = isError }; @@ -806,16 +738,18 @@ private async Task AddModelFromJsonElementAsync(JsonElement jsonElement, return $"{modelName}[]"; case JsonValueKind.Null: return "null"; + case JsonValueKind.Undefined: + return string.Empty; default: return string.Empty; } } - private async Task GetModelNameAsync(string name) + private string GetModelName(string name) { - Logger.LogTrace("Entered {Name}", nameof(GetModelNameAsync)); + Logger.LogTrace("Entered {Name}", nameof(GetModelName)); - var modelName = SanitizeName(await MakeSingularAsync(name)); + var modelName = SanitizeName(MakeSingularAsync(name)); if (string.IsNullOrEmpty(modelName)) { modelName = SanitizeName(name); @@ -828,42 +762,26 @@ private async Task GetModelNameAsync(string name) modelName = modelName.ToPascalCase(); Logger.LogDebug("Model name: {ModelName}", modelName); - Logger.LogTrace("Left {Name}", nameof(GetModelNameAsync)); + Logger.LogTrace("Left {Name}", nameof(GetModelName)); return modelName; } - - private async Task MakeSingularAsync(string noun, CancellationToken cancellationToken = default) + private string MakeSingularAsync(string noun) { Logger.LogTrace("Entered {Name}", nameof(MakeSingularAsync)); - var singularNoun = await languageModelClient.GenerateChatCompletionAsync("singular_noun", new() + var singular = noun; + if (noun.EndsWith("ies", StringComparison.OrdinalIgnoreCase)) { - { "noun", noun } - }, cancellationToken); - var singular = singularNoun?.Response; - - if (string.IsNullOrEmpty(singular) || - singular.Contains(' ', StringComparison.OrdinalIgnoreCase)) + singular = noun[0..^3] + 'y'; + } + else if (noun.EndsWith("es", StringComparison.OrdinalIgnoreCase)) { - if (noun.EndsWith("ies", StringComparison.OrdinalIgnoreCase)) - { - singular = noun[0..^3] + 'y'; - } - else if (noun.EndsWith("es", StringComparison.OrdinalIgnoreCase)) - { - singular = noun[0..^2]; - } - else if (noun.EndsWith('s') && !noun.EndsWith("ss", StringComparison.OrdinalIgnoreCase)) - { - singular = noun[0..^1]; - } - else - { - singular = noun; - } - - Logger.LogDebug("Failed to get singular form of {Noun} from LLM. Using fallback: {Singular}", noun, singular); + singular = noun[0..^2]; + } + else if (noun.EndsWith('s') || noun.EndsWith('S')) + { + singular = noun[0..^1]; } Logger.LogDebug("Singular form of '{Noun}': {Singular}", noun, singular); @@ -872,18 +790,6 @@ private async Task MakeSingularAsync(string noun, CancellationToken canc return singular; } - private string SanitizeName(string name) - { - Logger.LogTrace("Entered {Name}", nameof(SanitizeName)); - - var sanitized = Regex.Replace(name, "[^a-zA-Z0-9_]", ""); - - Logger.LogDebug("Sanitized name: {Name} to: {Sanitized}", name, sanitized); - Logger.LogTrace("Left {Name}", nameof(SanitizeName)); - - return sanitized; - } - private string GetValueType(string? value) { Logger.LogTrace("Entered {Name}", nameof(GetValueType)); @@ -911,4 +817,110 @@ private string GetValueType(string? value) return "string"; } + + private string GetJsonType(JsonElement element) + { + Logger.LogTrace("Entered {Name}", nameof(GetJsonType)); + + return element.ValueKind switch + { + JsonValueKind.Object => "object", + JsonValueKind.Array => "array", + JsonValueKind.String => "string", + JsonValueKind.Number => element.TryGetInt32(out _) ? "int32" : "float64", + JsonValueKind.True or JsonValueKind.False => "boolean", + JsonValueKind.Undefined or JsonValueKind.Null => "null", + _ => "string" + }; + } + + private (string Route, Parameter[] Parameters) GetRouteAndParametersAsync(Uri url) + { + Logger.LogTrace("Entered {Name}", nameof(GetRouteAndParametersAsync)); + + var route = new List(); + var parameters = new List(); + var previousSegment = "item"; + + foreach (var segment in url.Segments) + { + Logger.LogDebug("Processing segment: {Segment}", segment); + + var segmentTrimmed = segment.Trim('/'); + if (string.IsNullOrEmpty(segmentTrimmed)) + { + continue; + } + + if (IsParametrizable(segmentTrimmed)) + { + var paramName = $"{previousSegment}Id"; + parameters.Add(new() + { + Name = paramName, + Value = GetValueType(segmentTrimmed), + In = ParameterLocation.Path + }); + route.Add($"{{{paramName}}}"); + } + else + { + previousSegment = SanitizeName(MakeSingularAsync(segmentTrimmed)); + if (string.IsNullOrEmpty(previousSegment)) + { + previousSegment = SanitizeName(segmentTrimmed); + if (previousSegment.Length == 0) + { + previousSegment = GetRandomName(); + } + } + previousSegment = previousSegment.ToCamelCase(); + route.Add(segmentTrimmed); + } + } + + if (url.Query.Length > 0) + { + Logger.LogDebug("Processing query string: {Query}", url.Query); + + + var query = HttpUtility.ParseQueryString(url.Query); + foreach (string key in query.Keys) + { + if (Http.AuthHeaders.Contains(key.ToLowerInvariant())) + { + Logger.LogDebug("Skipping auth header: {Key}", key); + continue; + } + + parameters.Add(new() + { + Name = key.ToCamelFromKebabCase(), + Value = GetValueType(query[key]), + In = ParameterLocation.Query + }); + } + } + else + { + Logger.LogDebug("No query string found in URL: {Url}", url); + } + + Logger.LogTrace("Left {Name}", nameof(GetRouteAndParametersAsync)); + + return (string.Join('/', route), parameters.ToArray()); + } + + private bool IsParametrizable(string segment) + { + Logger.LogTrace("Entered {Name}", nameof(IsParametrizable)); + + var isParametrizable = Guid.TryParse(segment, out _) || + int.TryParse(segment, out _); + + Logger.LogDebug("Is segment '{Segment}' parametrizable? {IsParametrizable}", segment, isParametrizable); + Logger.LogTrace("Left {Name}", nameof(IsParametrizable)); + + return isParametrizable; + } } \ No newline at end of file diff --git a/DevProxy.Plugins/Guidance/CachingGuidancePlugin.cs b/DevProxy.Plugins/Guidance/CachingGuidancePlugin.cs index dc4fefe8..7e46479c 100644 --- a/DevProxy.Plugins/Guidance/CachingGuidancePlugin.cs +++ b/DevProxy.Plugins/Guidance/CachingGuidancePlugin.cs @@ -4,9 +4,9 @@ using DevProxy.Abstractions.Proxy; using DevProxy.Abstractions.Plugins; +using DevProxy.Abstractions.Utils; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; -using Titanium.Web.Proxy.Http; namespace DevProxy.Plugins.Guidance; @@ -32,32 +32,32 @@ public sealed class CachingGuidancePlugin( public override string Name => nameof(CachingGuidancePlugin); - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func? ProvideRequestGuidanceAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(ProvideRequestGuidanceAsync)); - ArgumentNullException.ThrowIfNull(e); + ArgumentNullException.ThrowIfNull(args); - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); return Task.CompletedTask; } - if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + if (args.Request.Method == HttpMethod.Options) { - Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, args.Request); return Task.CompletedTask; } - var request = e.Session.HttpClient.Request; - var url = request.RequestUri.AbsoluteUri; + var request = args.Request; + var url = request.RequestUri!.AbsoluteUri; var now = DateTime.Now; if (!_interceptedRequests.TryGetValue(url, out var value)) { value = now; _interceptedRequests.Add(url, value); - Logger.LogRequest("First request", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("First request", MessageType.Skipped, args.Request); return Task.CompletedTask; } @@ -65,19 +65,19 @@ public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken ca var secondsSinceLastIntercepted = (now - lastIntercepted).TotalSeconds; if (secondsSinceLastIntercepted <= Configuration.CacheThresholdSeconds) { - Logger.LogRequest(BuildCacheWarningMessage(request, Configuration.CacheThresholdSeconds, lastIntercepted), MessageType.Warning, new LoggingContext(e.Session)); + Logger.LogRequest(BuildCacheWarningMessage(request, Configuration.CacheThresholdSeconds, lastIntercepted), MessageType.Warning, args.Request); } else { - Logger.LogRequest("Request outside of cache window", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Request outside of cache window", MessageType.Skipped, args.Request); } _interceptedRequests[url] = now; - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); + Logger.LogTrace("Left {Name}", nameof(ProvideRequestGuidanceAsync)); return Task.CompletedTask; - } + }; - private static string BuildCacheWarningMessage(Request r, int _warningSeconds, DateTime lastIntercepted) => - $"Another request to {r.RequestUri.PathAndQuery} intercepted within {_warningSeconds} seconds. Last intercepted at {lastIntercepted}. Consider using cache to avoid calling the API too often."; + private static string BuildCacheWarningMessage(HttpRequestMessage r, int warningSeconds, DateTime lastIntercepted) => + $"Another request to {r.RequestUri!.PathAndQuery} intercepted within {warningSeconds} seconds. Last intercepted at {lastIntercepted}. Consider using cache to avoid calling the API too often."; } diff --git a/DevProxy.Plugins/Guidance/GraphBetaSupportGuidancePlugin.cs b/DevProxy.Plugins/Guidance/GraphBetaSupportGuidancePlugin.cs index 9f345941..81e1da56 100644 --- a/DevProxy.Plugins/Guidance/GraphBetaSupportGuidancePlugin.cs +++ b/DevProxy.Plugins/Guidance/GraphBetaSupportGuidancePlugin.cs @@ -15,33 +15,33 @@ public sealed class GraphBetaSupportGuidancePlugin( { public override string Name => nameof(GraphBetaSupportGuidancePlugin); - public override Task AfterResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) + public override Func? ProvideRequestGuidanceAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(AfterResponseAsync)); + Logger.LogTrace("{Method} called", nameof(ProvideRequestGuidanceAsync)); - ArgumentNullException.ThrowIfNull(e); + ArgumentNullException.ThrowIfNull(args); - var request = e.Session.HttpClient.Request; - if (!e.HasRequestUrlMatch(UrlsToWatch)) + var request = args.Request; + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); return Task.CompletedTask; } - if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + if (args.Request.Method == HttpMethod.Options) { - Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, args.Request); return Task.CompletedTask; } if (!ProxyUtils.IsGraphBetaRequest(request)) { - Logger.LogRequest("Not a Microsoft Graph beta request", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Not a Microsoft Graph beta request", MessageType.Skipped, args.Request); return Task.CompletedTask; } - Logger.LogRequest(BuildBetaSupportMessage(), MessageType.Warning, new(e.Session)); - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); + Logger.LogRequest(BuildBetaSupportMessage(), MessageType.Warning, args.Request); + Logger.LogTrace("Left {Name}", nameof(ProvideRequestGuidanceAsync)); return Task.CompletedTask; - } + }; private static string GetBetaSupportGuidanceUrl() => "https://aka.ms/devproxy/guidance/beta-support"; private static string BuildBetaSupportMessage() => diff --git a/DevProxy.Plugins/Guidance/GraphClientRequestIdGuidancePlugin.cs b/DevProxy.Plugins/Guidance/GraphClientRequestIdGuidancePlugin.cs index 5e610009..c8f12c12 100644 --- a/DevProxy.Plugins/Guidance/GraphClientRequestIdGuidancePlugin.cs +++ b/DevProxy.Plugins/Guidance/GraphClientRequestIdGuidancePlugin.cs @@ -7,7 +7,6 @@ using DevProxy.Abstractions.Utils; using DevProxy.Plugins.Utils; using Microsoft.Extensions.Logging; -using Titanium.Web.Proxy.Http; namespace DevProxy.Plugins.Guidance; @@ -17,45 +16,45 @@ public sealed class GraphClientRequestIdGuidancePlugin( { public override string Name => nameof(GraphClientRequestIdGuidancePlugin); - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func? ProvideRequestGuidanceAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(ProvideRequestGuidanceAsync)); - ArgumentNullException.ThrowIfNull(e); + ArgumentNullException.ThrowIfNull(args); - var request = e.Session.HttpClient.Request; - if (!e.HasRequestUrlMatch(UrlsToWatch)) + var request = args.Request; + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); return Task.CompletedTask; } - if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + if (args.Request.Method == HttpMethod.Options) { - Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, args.Request); return Task.CompletedTask; } if (WarnNoClientRequestId(request)) { - Logger.LogRequest(BuildAddClientRequestIdMessage(), MessageType.Warning, new(e.Session)); + Logger.LogRequest(BuildAddClientRequestIdMessage(), MessageType.Warning, args.Request); if (!ProxyUtils.IsSdkRequest(request)) { - Logger.LogRequest(MessageUtils.BuildUseSdkMessage(), MessageType.Tip, new(e.Session)); + Logger.LogRequest(MessageUtils.BuildUseSdkMessage(), MessageType.Tip, args.Request); } } else { - Logger.LogRequest("client-request-id header present", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("client-request-id header present", MessageType.Skipped, args.Request); } - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); + Logger.LogTrace("Left {Name}", nameof(ProvideRequestGuidanceAsync)); return Task.CompletedTask; - } + }; - private static bool WarnNoClientRequestId(Request request) => + private static bool WarnNoClientRequestId(HttpRequestMessage request) => ProxyUtils.IsGraphRequest(request) && - !request.Headers.HeaderExists("client-request-id"); + !request.Headers.Contains("client-request-id"); private static string GetClientRequestIdGuidanceUrl() => "https://aka.ms/devproxy/guidance/client-request-id"; private static string BuildAddClientRequestIdMessage() => diff --git a/DevProxy.Plugins/Guidance/GraphConnectorGuidancePlugin.cs b/DevProxy.Plugins/Guidance/GraphConnectorGuidancePlugin.cs index cfb5a425..29f4d45b 100644 --- a/DevProxy.Plugins/Guidance/GraphConnectorGuidancePlugin.cs +++ b/DevProxy.Plugins/Guidance/GraphConnectorGuidancePlugin.cs @@ -34,37 +34,42 @@ public sealed class GraphConnectorGuidancePlugin( { public override string Name => nameof(GraphConnectorGuidancePlugin); - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func? ProvideRequestGuidanceAsync => async (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(ProvideRequestGuidanceAsync)); - ArgumentNullException.ThrowIfNull(e); + ArgumentNullException.ThrowIfNull(args); - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); + return; } - if (!string.Equals(e.Session.HttpClient.Request.Method, "PATCH", StringComparison.OrdinalIgnoreCase)) + if (args.Request.Method != HttpMethod.Patch) { - Logger.LogRequest("Skipping non-PATCH request", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("Skipping non-PATCH request", MessageType.Skipped, args.Request); + return; } try { - var schemaString = e.Session.HttpClient.Request.BodyString; + var schemaString = string.Empty; + if (args.Request.Content is not null) + { + schemaString = await args.Request.Content.ReadAsStringAsync(cancellationToken); + } + if (string.IsNullOrEmpty(schemaString)) { - Logger.LogRequest("No schema found in the request body.", MessageType.Failed, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("No schema found in the request body.", MessageType.Failed, args.Request); + return; } var schema = JsonSerializer.Deserialize(schemaString, ProxyUtils.JsonSerializerOptions); if (schema is null || schema.Properties is null) { - Logger.LogRequest("Invalid schema found in the request body.", MessageType.Failed, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("Invalid schema found in the request body.", MessageType.Failed, args.Request); + return; } bool hasTitle = false, hasIconUrl = false, hasUrl = false; @@ -99,12 +104,12 @@ public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken ca Logger.LogRequest( $"The schema is missing the following semantic labels: {string.Join(", ", missingLabels.Where(s => !string.IsNullOrEmpty(s)))}. Ingested content might not show up in Microsoft Copilot for Microsoft 365. More information: https://aka.ms/devproxy/guidance/gc/ux", - MessageType.Failed, new(e.Session) + MessageType.Failed, args.Request ); } else { - Logger.LogRequest("The schema contains all the required semantic labels.", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("The schema contains all the required semantic labels.", MessageType.Skipped, args.Request); } } catch (Exception ex) @@ -112,7 +117,6 @@ public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken ca Logger.LogError(ex, "An error has occurred while deserializing the request body"); } - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); - return Task.CompletedTask; - } + Logger.LogTrace("Left {Name}", nameof(ProvideRequestGuidanceAsync)); + }; } diff --git a/DevProxy.Plugins/Guidance/GraphSdkGuidancePlugin.cs b/DevProxy.Plugins/Guidance/GraphSdkGuidancePlugin.cs index fae1070d..8546cbb5 100644 --- a/DevProxy.Plugins/Guidance/GraphSdkGuidancePlugin.cs +++ b/DevProxy.Plugins/Guidance/GraphSdkGuidancePlugin.cs @@ -7,7 +7,6 @@ using DevProxy.Abstractions.Utils; using DevProxy.Plugins.Utils; using Microsoft.Extensions.Logging; -using Titanium.Web.Proxy.Http; namespace DevProxy.Plugins.Guidance; @@ -17,45 +16,42 @@ public sealed class GraphSdkGuidancePlugin( { public override string Name => nameof(GraphSdkGuidancePlugin); - public override Task AfterResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) + public override Func? ProvideResponseGuidanceAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(AfterResponseAsync)); + Logger.LogTrace("{Method} called", nameof(ProvideResponseGuidanceAsync)); - ArgumentNullException.ThrowIfNull(e); - - var request = e.Session.HttpClient.Request; - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); return Task.CompletedTask; } - if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + if (args.Request.Method == HttpMethod.Options) { - Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, args.Request); return Task.CompletedTask; } // only show the message if there is an error. - if (e.Session.HttpClient.Response.StatusCode >= 400) + if ((int)args.Response.StatusCode >= 400) { - if (WarnNoSdk(request)) + if (WarnNoSdk(args.Request)) { - Logger.LogRequest(MessageUtils.BuildUseSdkForErrorsMessage(), MessageType.Tip, new(e.Session)); + Logger.LogRequest(MessageUtils.BuildUseSdkForErrorsMessage(), MessageType.Tip, args.Request); } else { - Logger.LogRequest("Request issued using SDK", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Request issued using SDK", MessageType.Skipped, args.Request); } } else { - Logger.LogRequest("Skipping non-error response", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Skipping non-error response", MessageType.Skipped, args.Request); } - Logger.LogTrace("Left {Name}", nameof(AfterResponseAsync)); + Logger.LogTrace("Left {Name}", nameof(ProvideResponseGuidanceAsync)); return Task.CompletedTask; - } + }; - private static bool WarnNoSdk(Request request) => + private static bool WarnNoSdk(HttpRequestMessage request) => ProxyUtils.IsGraphRequest(request) && !ProxyUtils.IsSdkRequest(request); } diff --git a/DevProxy.Plugins/Guidance/GraphSelectGuidancePlugin.cs b/DevProxy.Plugins/Guidance/GraphSelectGuidancePlugin.cs index 9f41574e..059bfbfb 100644 --- a/DevProxy.Plugins/Guidance/GraphSelectGuidancePlugin.cs +++ b/DevProxy.Plugins/Guidance/GraphSelectGuidancePlugin.cs @@ -8,7 +8,6 @@ using DevProxy.Abstractions.Utils; using Microsoft.Extensions.Logging; using System.Globalization; -using Titanium.Web.Proxy.EventArguments; namespace DevProxy.Plugins.Guidance; @@ -29,53 +28,53 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell _ = _msGraphDb.GenerateDbAsync(true, cancellationToken); } - public override Task AfterResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) + public override Func? ProvideRequestGuidanceAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(AfterResponseAsync)); + Logger.LogTrace("{Method} called", nameof(ProvideRequestGuidanceAsync)); - ArgumentNullException.ThrowIfNull(e); + ArgumentNullException.ThrowIfNull(args); - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); return Task.CompletedTask; } - if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + if (args.Request.Method == HttpMethod.Options) { - Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, args.Request); return Task.CompletedTask; } - if (WarnNoSelect(e.Session)) + if (WarnNoSelect(args.Request)) { - Logger.LogRequest(BuildUseSelectMessage(), MessageType.Warning, new(e.Session)); + Logger.LogRequest(BuildUseSelectMessage(), MessageType.Warning, args.Request); } - Logger.LogTrace("Left {Name}", nameof(AfterResponseAsync)); + Logger.LogTrace("Left {Name}", nameof(ProvideRequestGuidanceAsync)); return Task.CompletedTask; - } + }; - private bool WarnNoSelect(SessionEventArgs session) + private bool WarnNoSelect(HttpRequestMessage request) { - var request = session.HttpClient.Request; if (!ProxyUtils.IsGraphRequest(request) || - request.Method != "GET") + request.Method != HttpMethod.Get) { - Logger.LogRequest("Not a Microsoft Graph GET request", MessageType.Skipped, new(session)); + Logger.LogRequest("Not a Microsoft Graph GET request", MessageType.Skipped, request); return false; } - var graphVersion = ProxyUtils.GetGraphVersion(request.RequestUri.AbsoluteUri); + var graphVersion = ProxyUtils.GetGraphVersion(request.RequestUri!.AbsoluteUri); var tokenizedUrl = GetTokenizedUrl(request.RequestUri.AbsoluteUri); if (EndpointSupportsSelect(graphVersion, tokenizedUrl)) { - return !request.Url.Contains("$select", StringComparison.OrdinalIgnoreCase) && - !request.Url.Contains("%24select", StringComparison.OrdinalIgnoreCase); + var url = request.RequestUri.AbsoluteUri; + return !url.Contains("$select", StringComparison.OrdinalIgnoreCase) && + !url.Contains("%24select", StringComparison.OrdinalIgnoreCase); } else { - Logger.LogRequest("Endpoint does not support $select", MessageType.Skipped, new(session)); + Logger.LogRequest("Endpoint does not support $select", MessageType.Skipped, request); return false; } } diff --git a/DevProxy.Plugins/Guidance/ODSPSearchGuidancePlugin.cs b/DevProxy.Plugins/Guidance/ODSPSearchGuidancePlugin.cs index 7cdb52b2..b8e80adf 100644 --- a/DevProxy.Plugins/Guidance/ODSPSearchGuidancePlugin.cs +++ b/DevProxy.Plugins/Guidance/ODSPSearchGuidancePlugin.cs @@ -6,7 +6,6 @@ using DevProxy.Abstractions.Plugins; using DevProxy.Abstractions.Utils; using Microsoft.Extensions.Logging; -using Titanium.Web.Proxy.EventArguments; namespace DevProxy.Plugins.Guidance; @@ -16,39 +15,36 @@ public sealed class ODSPSearchGuidancePlugin( { public override string Name => nameof(ODSPSearchGuidancePlugin); - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func? ProvideRequestGuidanceAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(ProvideRequestGuidanceAsync)); - ArgumentNullException.ThrowIfNull(e); - - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); return Task.CompletedTask; } - if (string.Equals(e.Session.HttpClient.Request.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)) + if (args.Request.Method == HttpMethod.Options) { - Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Skipping OPTIONS request", MessageType.Skipped, args.Request); return Task.CompletedTask; } - if (WarnDeprecatedSearch(e.Session)) + if (WarnDeprecatedSearch(args.Request)) { - Logger.LogRequest(BuildUseGraphSearchMessage(), MessageType.Warning, new LoggingContext(e.Session)); + Logger.LogRequest(BuildUseGraphSearchMessage(), MessageType.Warning, args.Request); } - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); + Logger.LogTrace("Left {Name}", nameof(ProvideRequestGuidanceAsync)); return Task.CompletedTask; - } + }; - private bool WarnDeprecatedSearch(SessionEventArgs session) + private bool WarnDeprecatedSearch(HttpRequestMessage request) { - var request = session.HttpClient.Request; if (!ProxyUtils.IsGraphRequest(request) || - request.Method != "GET") + request.Method != HttpMethod.Get) { - Logger.LogRequest("Not a Microsoft Graph GET request", MessageType.Skipped, new(session)); + Logger.LogRequest("Not a Microsoft Graph GET request", MessageType.Skipped, request); return false; } @@ -58,15 +54,16 @@ private bool WarnDeprecatedSearch(SessionEventArgs session) // graph.microsoft.com/{version}/sites/{site-id}/drive/root/search(q='{search-text}') // graph.microsoft.com/{version}/users/{user-id}/drive/root/search(q='{search-text}') // graph.microsoft.com/{version}/sites?search={query} - if (request.RequestUri.AbsolutePath.Contains("/search(q=", StringComparison.OrdinalIgnoreCase) || + if (request.RequestUri != null && + (request.RequestUri.AbsolutePath.Contains("/search(q=", StringComparison.OrdinalIgnoreCase) || (request.RequestUri.AbsolutePath.EndsWith("/sites", StringComparison.OrdinalIgnoreCase) && - request.RequestUri.Query.Contains("search=", StringComparison.OrdinalIgnoreCase))) + request.RequestUri.Query.Contains("search=", StringComparison.OrdinalIgnoreCase)))) { return true; } else { - Logger.LogRequest("Not a SharePoint search request", MessageType.Skipped, new(session)); + Logger.LogRequest("Not a SharePoint search request", MessageType.Skipped, request); return false; } } diff --git a/DevProxy.Plugins/Guidance/ODataPagingGuidancePlugin.cs b/DevProxy.Plugins/Guidance/ODataPagingGuidancePlugin.cs index 733948f8..18dc3c76 100644 --- a/DevProxy.Plugins/Guidance/ODataPagingGuidancePlugin.cs +++ b/DevProxy.Plugins/Guidance/ODataPagingGuidancePlugin.cs @@ -19,89 +19,89 @@ public sealed class ODataPagingGuidancePlugin( public override string Name => nameof(ODataPagingGuidancePlugin); - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func? ProvideRequestGuidanceAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(ProvideRequestGuidanceAsync)); - ArgumentNullException.ThrowIfNull(e); - - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); return Task.CompletedTask; } - if (!string.Equals(e.Session.HttpClient.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) + if (args.Request.Method != HttpMethod.Get) { - Logger.LogRequest("Skipping non-GET request", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Skipping non-GET request", MessageType.Skipped, args.Request); return Task.CompletedTask; } - if (IsODataPagingUrl(e.Session.HttpClient.Request.RequestUri)) + if (args.Request.RequestUri != null && IsODataPagingUrl(args.Request.RequestUri)) { - if (!pagingUrls.Contains(e.Session.HttpClient.Request.Url)) + if (!pagingUrls.Contains(args.Request.RequestUri.ToString())) { - Logger.LogRequest(BuildIncorrectPagingUrlMessage(), MessageType.Warning, new(e.Session)); + Logger.LogRequest(BuildIncorrectPagingUrlMessage(), MessageType.Warning, args.Request); } else { - Logger.LogRequest("Paging URL is correct", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Paging URL is correct", MessageType.Skipped, args.Request); } } else { - Logger.LogRequest("Not an OData paging URL", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Not an OData paging URL", MessageType.Skipped, args.Request); } - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); + Logger.LogTrace("Left {Name}", nameof(ProvideRequestGuidanceAsync)); return Task.CompletedTask; - } + }; - public override async Task BeforeResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) + public override Func? ProvideResponseGuidanceAsync => async (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeResponseAsync)); + Logger.LogTrace("{Method} called", nameof(ProvideResponseGuidanceAsync)); - ArgumentNullException.ThrowIfNull(e); - - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); return; } - if (!string.Equals(e.Session.HttpClient.Request.Method, "GET", StringComparison.OrdinalIgnoreCase)) + if (args.Request.Method != HttpMethod.Get) { - Logger.LogRequest("Skipping non-GET request", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Skipping non-GET request", MessageType.Skipped, args.Request); return; } - if (e.Session.HttpClient.Response.StatusCode >= 300) + if ((int)args.Response.StatusCode >= 300) { - Logger.LogRequest("Skipping non-success response", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Skipping non-success response", MessageType.Skipped, args.Request); return; } - if (e.Session.HttpClient.Response.ContentType is null || - (!e.Session.HttpClient.Response.ContentType.Contains("json", StringComparison.OrdinalIgnoreCase) && - !e.Session.HttpClient.Response.ContentType.Contains("application/atom+xml", StringComparison.OrdinalIgnoreCase)) || - !e.Session.HttpClient.Response.HasBody) + + var mediaType = args.Response.Content?.Headers?.ContentType?.MediaType; + if (mediaType is null || + (!mediaType.Contains("json", StringComparison.OrdinalIgnoreCase) && + !mediaType.Contains("application/atom+xml", StringComparison.OrdinalIgnoreCase))) { - Logger.LogRequest("Skipping response with unsupported body type", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Skipping response with unsupported body type", MessageType.Skipped, args.Request); return; } - e.Session.HttpClient.Response.KeepBody = true; + if (args.Response.Content is null) + { + Logger.LogRequest("Skipping response with no content", MessageType.Skipped, args.Request); + return; + } var nextLink = string.Empty; - var bodyString = await e.Session.GetResponseBodyAsString(cancellationToken); + var bodyString = await args.Response.Content.ReadAsStringAsync(cancellationToken); if (string.IsNullOrEmpty(bodyString)) { - Logger.LogRequest("Skipping empty response body", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Skipping empty response body", MessageType.Skipped, args.Request); return; } - var contentType = e.Session.HttpClient.Response.ContentType; - if (contentType.Contains("json", StringComparison.OrdinalIgnoreCase)) + if (mediaType.Contains("json", StringComparison.OrdinalIgnoreCase)) { nextLink = GetNextLinkFromJson(bodyString); } - else if (contentType.Contains("application/atom+xml", StringComparison.OrdinalIgnoreCase)) + else if (mediaType.Contains("application/atom+xml", StringComparison.OrdinalIgnoreCase)) { nextLink = GetNextLinkFromXml(bodyString); } @@ -112,11 +112,11 @@ public override async Task BeforeResponseAsync(ProxyResponseArgs e, Cancellation } else { - Logger.LogRequest("No next link found in the response", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("No next link found in the response", MessageType.Skipped, args.Request); } - Logger.LogTrace("Left {Name}", nameof(BeforeResponseAsync)); - } + Logger.LogTrace("Left {Name}", nameof(ProvideResponseGuidanceAsync)); + }; private string GetNextLinkFromJson(string responseBody) { diff --git a/DevProxy.Plugins/Inspection/DevToolsPlugin.cs b/DevProxy.Plugins/Inspection/DevToolsPlugin.cs index fbd50491..ab28b446 100644 --- a/DevProxy.Plugins/Inspection/DevToolsPlugin.cs +++ b/DevProxy.Plugins/Inspection/DevToolsPlugin.cs @@ -13,7 +13,6 @@ using System.Net.Sockets; using System.Runtime.InteropServices; using System.Text.Json; -using System.Globalization; using System.Text; namespace DevProxy.Plugins.Inspection; @@ -68,27 +67,35 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell InitInspector(); } - public override async Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func? ProvideRequestGuidanceAsync => async (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); - - ArgumentNullException.ThrowIfNull(e); + Logger.LogTrace("{Method} called", nameof(ProvideRequestGuidanceAsync)); if (_webSocket?.IsConnected != true) { return; } - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); return; } - var requestId = GetRequestId(e.Session.HttpClient.Request); - var headers = e.Session.HttpClient.Request.Headers - .GroupBy(h => h.Name) - .ToDictionary(g => g.Key, g => string.Join(", ", g.Select(h => h.Value))); + var requestId = args.RequestId!; + var headers = args.Request.Headers + .ToDictionary(h => h.Key, h => string.Join(", ", h.Value)); + + // Add content headers if they exist + if (args.Request.Content?.Headers != null) + { + foreach (var header in args.Request.Content.Headers) + { + headers[header.Key] = string.Join(", ", header.Value); + } + } + + var postData = args.Request.Content != null ? await args.Request.Content.ReadAsStringAsync(cancellationToken) : null; var requestWillBeSentMessage = new RequestWillBeSentMessage { @@ -96,13 +103,13 @@ public override async Task BeforeRequestAsync(ProxyRequestArgs e, CancellationTo { RequestId = requestId, LoaderId = "1", - DocumentUrl = e.Session.HttpClient.Request.Url, + DocumentUrl = args.Request.RequestUri?.ToString() ?? string.Empty, Request = new() { - Url = e.Session.HttpClient.Request.Url, - Method = e.Session.HttpClient.Request.Method, + Url = args.Request.RequestUri?.ToString() ?? string.Empty, + Method = args.Request.Method.Method, Headers = headers, - PostData = e.Session.HttpClient.Request.HasBody ? e.Session.HttpClient.Request.BodyString : null + PostData = postData }, Timestamp = (double)DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() / 1000, WallTime = DateTimeOffset.UtcNow.ToUnixTimeSeconds(), @@ -126,24 +133,20 @@ public override async Task BeforeRequestAsync(ProxyRequestArgs e, CancellationTo } }; await _webSocket.SendAsync(requestWillBeSentExtraInfoMessage, cancellationToken); - } + }; - public override async Task AfterResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) + public override Func? ProvideResponseGuidanceAsync => async (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(AfterResponseAsync)); - - ArgumentNullException.ThrowIfNull(e); - - await base.AfterResponseAsync(e, cancellationToken); + Logger.LogTrace("{Method} called", nameof(ProvideResponseGuidanceAsync)); if (_webSocket?.IsConnected != true) { return; } - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request, args.RequestId); return; } @@ -152,22 +155,38 @@ public override async Task AfterResponseAsync(ProxyResponseArgs e, CancellationT Body = string.Empty, Base64Encoded = false }; - if (e.Session.HttpClient.Response.HasBody) + + if (args.Response.Content != null) { - if (IsTextResponse(e.Session.HttpClient.Response.ContentType)) + var contentType = args.Response.Content.Headers.ContentType?.MediaType; + if (IsTextResponse(contentType)) { - body.Body = e.Session.HttpClient.Response.BodyString; + body.Body = await args.Response.Content.ReadAsStringAsync(cancellationToken); body.Base64Encoded = false; } else { - body.Body = Convert.ToBase64String(e.Session.HttpClient.Response.Body); + var bytes = await args.Response.Content.ReadAsByteArrayAsync(cancellationToken); + body.Body = Convert.ToBase64String(bytes); body.Base64Encoded = true; } } - _responseBody.Add(e.Session.HttpClient.Request.GetHashCode().ToString(CultureInfo.InvariantCulture), body); - var requestId = GetRequestId(e.Session.HttpClient.Request); + _responseBody.Add(args.RequestId, body); + + var requestId = args.RequestId!; + + var responseHeaders = args.Response.Headers + .ToDictionary(h => h.Key, h => string.Join(", ", h.Value)); + + // Add content headers if they exist + if (args.Response.Content?.Headers != null) + { + foreach (var header in args.Response.Content.Headers) + { + responseHeaders[header.Key] = string.Join(", ", header.Value); + } + } var responseReceivedMessage = new ResponseReceivedMessage { @@ -179,13 +198,11 @@ public override async Task AfterResponseAsync(ProxyResponseArgs e, CancellationT Type = "XHR", Response = new() { - Url = e.Session.HttpClient.Request.Url, - Status = e.Session.HttpClient.Response.StatusCode, - StatusText = e.Session.HttpClient.Response.StatusDescription, - Headers = e.Session.HttpClient.Response.Headers - .GroupBy(h => h.Name) - .ToDictionary(g => g.Key, g => string.Join(", ", g.Select(h => h.Value))), - MimeType = e.Session.HttpClient.Response.ContentType + Url = args.Request.RequestUri?.ToString() ?? string.Empty, + Status = (int)args.Response.StatusCode, + StatusText = args.Response.ReasonPhrase ?? string.Empty, + Headers = responseHeaders, + MimeType = args.Response.Content?.Headers.ContentType?.MediaType ?? string.Empty }, HasExtraInfo = true } @@ -193,7 +210,7 @@ public override async Task AfterResponseAsync(ProxyResponseArgs e, CancellationT await _webSocket.SendAsync(responseReceivedMessage, cancellationToken); - if (e.Session.HttpClient.Response.ContentType == "text/event-stream") + if (args.Response.Content?.Headers.ContentType?.MediaType == "text/event-stream") { await SendBodyAsDataReceivedAsync(requestId, body.Body, cancellationToken); } @@ -204,13 +221,15 @@ public override async Task AfterResponseAsync(ProxyResponseArgs e, CancellationT { RequestId = requestId, Timestamp = (double)DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() / 1000, - EncodedDataLength = e.Session.HttpClient.Response.HasBody ? e.Session.HttpClient.Response.Body.Length : 0 + EncodedDataLength = args.Response.Content != null ? (await args.Response.Content.ReadAsByteArrayAsync(cancellationToken)).Length : 0 } }; await _webSocket.SendAsync(loadingFinishedMessage, cancellationToken); - } + }; - public override async Task AfterRequestLogAsync(RequestLogArgs e, CancellationToken cancellationToken) + public override Func? HandleRequestLogAsync => AfterRequestLogAsync; + + public async Task AfterRequestLogAsync(RequestLogArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRequestLogAsync)); @@ -233,8 +252,8 @@ public override async Task AfterRequestLogAsync(RequestLogArgs e, CancellationTo Text = string.Join(" ", e.RequestLog.Message), Level = Entry.GetLevel(e.RequestLog.MessageType), Timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), - Url = e.RequestLog.Context?.Session.HttpClient.Request.Url, - NetworkRequestId = GetRequestId(e.RequestLog.Context?.Session.HttpClient.Request) + Url = e.RequestLog.Request?.RequestUri?.ToString(), + NetworkRequestId = e.RequestLog.RequestId! } } }; @@ -447,16 +466,6 @@ private static int GetFreePort() return port; } - private static string GetRequestId(Titanium.Web.Proxy.Http.Request? request) - { - if (request is null) - { - return string.Empty; - } - - return request.GetHashCode().ToString(CultureInfo.InvariantCulture); - } - private static bool IsTextResponse(string? contentType) { var isTextResponse = false; diff --git a/DevProxy.Plugins/Inspection/OpenAITelemetryPlugin.cs b/DevProxy.Plugins/Inspection/OpenAITelemetryPlugin.cs index 7707cdf5..0b6e66d0 100644 --- a/DevProxy.Plugins/Inspection/OpenAITelemetryPlugin.cs +++ b/DevProxy.Plugins/Inspection/OpenAITelemetryPlugin.cs @@ -3,23 +3,22 @@ // See the LICENSE file in the project root for more information. using DevProxy.Abstractions.LanguageModel; -using DevProxy.Abstractions.OpenTelemetry; -using DevProxy.Abstractions.Plugins; -using DevProxy.Abstractions.Proxy; -using DevProxy.Abstractions.Utils; -using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using OpenTelemetry; -using OpenTelemetry.Exporter; -using OpenTelemetry.Metrics; -using OpenTelemetry.Resources; -using OpenTelemetry.Trace; -using System.Collections.Concurrent; -using System.Diagnostics; -using System.Diagnostics.Metrics; -using System.Text.Json; -using Titanium.Web.Proxy.Http; +//using DevProxy.Abstractions.OpenTelemetry; +//using DevProxy.Abstractions.Plugins; +//using DevProxy.Abstractions.Proxy; +//using DevProxy.Abstractions.Utils; +//using Microsoft.Extensions.Configuration; +//using Microsoft.Extensions.DependencyInjection; +//using Microsoft.Extensions.Logging; +//using OpenTelemetry; +//using OpenTelemetry.Exporter; +//using OpenTelemetry.Metrics; +//using OpenTelemetry.Resources; +//using OpenTelemetry.Trace; +//using System.Collections.Concurrent; +//using System.Diagnostics; +//using System.Diagnostics.Metrics; +//using System.Text.Json; namespace DevProxy.Plugins.Inspection; @@ -42,1030 +41,1032 @@ public sealed class OpenAITelemetryPluginConfiguration : LanguageModelPricesPlug public bool IncludePrompt { get; set; } = true; } -public sealed class OpenAITelemetryPlugin( - HttpClient httpClient, - ILogger logger, - ISet urlsToWatch, - IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : - BaseReportingPlugin( - httpClient, - logger, - urlsToWatch, - proxyConfiguration, - pluginConfigurationSection), IDisposable -{ - private const string ActivitySourceName = "DevProxy.OpenAI"; - private const string OpenAISystem = "openai"; - - private static readonly Meter _meter = new(ActivitySourceName); - private static Histogram? _requestCostMetric; - private static Counter? _totalCostMetric; - private static Histogram? _tokenUsageMetric; - - private readonly ActivitySource _activitySource = new(ActivitySourceName); - private LanguageModelPricesLoader? _loader; - private MeterProvider? _meterProvider; - private TracerProvider? _tracerProvider; - private readonly ConcurrentDictionary> _modelUsage = []; - - public override string Name => nameof(OpenAITelemetryPlugin); - - public override async Task InitializeAsync(InitArgs e, CancellationToken cancellationToken) - { - ArgumentNullException.ThrowIfNull(e); - - await base.InitializeAsync(e, cancellationToken); - - if (Configuration.IncludeCosts) - { - Configuration.PricesFile = ProxyUtils.GetFullPath(Configuration.PricesFile, ProxyConfiguration.ConfigFile); - _loader = ActivatorUtilities.CreateInstance(e.ServiceProvider, Configuration); - await _loader.InitFileWatcherAsync(cancellationToken); - } - - InitializeOpenTelemetryExporter(); - } - - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) - { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); - - ArgumentNullException.ThrowIfNull(e); - - if (!e.HasRequestUrlMatch(UrlsToWatch)) - { - Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); - return Task.CompletedTask; - } - - var request = e.Session.HttpClient.Request; - if (request.Method is null || - !request.Method.Equals("POST", StringComparison.OrdinalIgnoreCase) || - !request.HasBody) - { - Logger.LogRequest("Request is not a POST request with a body", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; - } - - if (!TryGetOpenAIRequest(request.BodyString, out var openAiRequest) || openAiRequest is null) - { - Logger.LogRequest("Skipping non-OpenAI request", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; - } - - // store for use in response - e.SessionData["OpenAIRequest"] = openAiRequest; - - var activity = _activitySource.StartActivity( - $"openai.{GetOperationName(openAiRequest)}", - ActivityKind.Client); - - if (activity is null) - { - Logger.LogWarning("Failed to start OpenTelemetry activity for OpenAI request"); - return Task.CompletedTask; - } - - // add generic request tags - _ = activity.SetTag("http.method", request.Method) - .SetTag("http.url", request.RequestUri.ToString()) - .SetTag("http.scheme", request.RequestUri.Scheme) - .SetTag("http.host", request.RequestUri.Host) - .SetTag("http.target", request.RequestUri.PathAndQuery) - .SetTag(SemanticConvention.GEN_AI_SYSTEM, OpenAISystem) - .SetTag(SemanticConvention.GEN_AI_ENVIRONMENT, Configuration.Environment) - .SetTag(SemanticConvention.GEN_AI_APPLICATION_NAME, Configuration.Application); - - AddCommonRequestTags(activity, openAiRequest); - AddRequestTypeSpecificTags(activity, openAiRequest); - - // store for use in response - e.SessionData["OpenAIActivity"] = activity; - - Logger.LogTrace("OnRequestAsync() finished"); - return Task.CompletedTask; - } - - public override Task AfterResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) - { - Logger.LogTrace("{Method} called", nameof(AfterResponseAsync)); - - ArgumentNullException.ThrowIfNull(e); - - if (!e.SessionData.TryGetValue("OpenAIActivity", out var activityObj) || - activityObj is not Activity activity) - { - return Task.CompletedTask; - } - - try - { - var response = e.Session.HttpClient.Response; - - _ = activity.SetTag("http.status_code", response.StatusCode); - -#pragma warning disable IDE0010 - switch (response.StatusCode) -#pragma warning restore IDE0010 - { - case int code when code is >= 200 and < 300: - ProcessSuccessResponse(activity, e); - break; - case int code when code >= 400: - ProcessErrorResponse(activity, e); - break; - } - } - finally - { - activity.Stop(); - - // Clean up session data - _ = e.SessionData.Remove("OpenAIActivity"); - _ = e.SessionData.Remove("OpenAIRequest"); - - Logger.LogRequest("OpenTelemetry information emitted", MessageType.Processed, new(e.Session)); - } - - Logger.LogTrace("Left {Name}", nameof(AfterResponseAsync)); - return Task.CompletedTask; - } - - public override Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) - { - Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); - - var report = new OpenAITelemetryPluginReport - { - Application = Configuration.Application, - Environment = Configuration.Environment, - Currency = Configuration.Currency, - IncludeCosts = Configuration.IncludeCosts, - ModelUsage = _modelUsage.ToDictionary() - }; - - StoreReport(report, e); - _modelUsage.Clear(); - - Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); - return Task.CompletedTask; - } - - private void InitializeOpenTelemetryExporter() - { - Logger.LogTrace("InitializeOpenTelemetryExporter() called"); - - try - { - void configureOtlpExporter(OtlpExporterOptions options) - { - // We use protobuf to allow intercepting Dev Proxy's own LLM traffic - options.Protocol = OtlpExportProtocol.HttpProtobuf; - options.Endpoint = new Uri(Configuration.ExporterEndpoint + "/v1/traces"); - } - - var resourceBuilder = ResourceBuilder - .CreateDefault() - .AddService(serviceName: "DevProxy.OpenAI", serviceVersion: ProxyUtils.ProductVersion); - - _tracerProvider = Sdk.CreateTracerProviderBuilder() - .SetResourceBuilder(resourceBuilder) - .AddSource(ActivitySourceName) - .AddOtlpExporter(configureOtlpExporter) - .Build(); - - _meterProvider = Sdk.CreateMeterProviderBuilder() - .SetResourceBuilder(resourceBuilder) - .AddMeter(ActivitySourceName) - .AddView(SemanticConvention.GEN_AI_METRIC_CLIENT_TOKEN_USAGE, new ExplicitBucketHistogramConfiguration - { - Boundaries = [1, 4, 16, 64, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304, 16777216, 67108864] - }) - .AddView(SemanticConvention.GEN_AI_USAGE_COST, new ExplicitBucketHistogramConfiguration - { - Boundaries = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 5, 10, 50, 100] - }) - .AddView(SemanticConvention.GEN_AI_USAGE_TOTAL_COST, new MetricStreamConfiguration()) - .AddOtlpExporter(configureOtlpExporter) - .Build(); - - _tokenUsageMetric = _meter.CreateHistogram( - SemanticConvention.GEN_AI_METRIC_CLIENT_TOKEN_USAGE, - "tokens", - "Number of tokens processed"); - _requestCostMetric = _meter.CreateHistogram( - SemanticConvention.GEN_AI_USAGE_COST, - "cost", - $"Estimated cost per request in {Configuration.Currency}"); - _totalCostMetric = _meter.CreateCounter( - SemanticConvention.GEN_AI_USAGE_TOTAL_COST, - "cost", - $"Total estimated cost for the session in {Configuration.Currency}"); - } - catch (Exception ex) - { - Logger.LogError(ex, "Failed to initialize OpenTelemetry exporter"); - } - - Logger.LogTrace("InitializeOpenTelemetryExporter() finished"); - } - - private void ProcessErrorResponse(Activity activity, ProxyResponseArgs e) - { - Logger.LogTrace("ProcessErrorResponse() called"); - - var response = e.Session.HttpClient.Response; - - _ = activity.SetTag("error", true) - .SetTag("error.type", "http") - .SetTag("error.message", $"HTTP {response.StatusCode}"); - - if (response.HasBody) - { - try - { - var errorObj = JsonSerializer.Deserialize(response.BodyString); - if (errorObj.TryGetProperty("error", out var error)) - { - if (error.TryGetProperty("message", out var message)) - { - _ = activity.SetTag("error.details", message.GetString()); - } - } - } - catch (JsonException) - { - // Ignore JSON parsing errors in error responses - } - } - - Logger.LogTrace("ProcessErrorResponse() finished"); - } - - private void ProcessSuccessResponse(Activity activity, ProxyResponseArgs e) - { - Logger.LogTrace("ProcessSuccessResponse() called"); - - var response = e.Session.HttpClient.Response; - - if (!response.HasBody || string.IsNullOrEmpty(response.BodyString)) - { - Logger.LogDebug("Response body is empty or null"); - return; - } - - if (!e.SessionData.TryGetValue("OpenAIRequest", out var requestObj) || - requestObj is not OpenAIRequest openAiRequest) - { - Logger.LogDebug("OpenAI request not found in session data"); - return; - } - - var bodyString = response.BodyString; - if (IsStreamingResponse(response)) - { - bodyString = GetBodyFromStreamingResponse(response); - } - - AddResponseTypeSpecificTags(activity, openAiRequest, bodyString); - - Logger.LogTrace("ProcessSuccessResponse() finished"); - } - - private void AddResponseTypeSpecificTags(Activity activity, OpenAIRequest openAiRequest, string responseBody) - { - Logger.LogTrace("AddResponseTypeSpecificTags() called"); - - try - { - switch (openAiRequest) - { - case OpenAIChatCompletionRequest: - AddChatCompletionResponseTags(activity, openAiRequest, responseBody); - break; - case OpenAICompletionRequest: - AddCompletionResponseTags(activity, openAiRequest, responseBody); - break; - case OpenAIEmbeddingRequest: - AddEmbeddingResponseTags(activity, openAiRequest, responseBody); - break; - case OpenAIImageRequest: - AddImageResponseTags(activity, openAiRequest, responseBody); - break; - case OpenAIAudioRequest: - AddAudioResponseTags(activity, openAiRequest, responseBody); - break; - case OpenAIFineTuneRequest: - AddFineTuneResponseTags(activity, openAiRequest, responseBody); - break; - default: - throw new InvalidOperationException($"Unsupported OpenAI request type: {openAiRequest.GetType().Name}"); - } - } - catch (JsonException ex) - { - Logger.LogError(ex, "Failed to deserialize OpenAI response"); - _ = activity.SetTag("error", ex.Message); - } - } - - private void AddFineTuneResponseTags(Activity activity, OpenAIRequest openAIRequest, string responseBody) - { - Logger.LogTrace("AddFineTuneResponseTags() called"); - - var fineTuneResponse = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); - if (fineTuneResponse is null) - { - return; - } - - RecordUsageMetrics(activity, openAIRequest, fineTuneResponse); - - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_FINETUNE_STATUS, fineTuneResponse.Status) - .SetTag(SemanticConvention.GEN_AI_RESPONSE_ID, fineTuneResponse.Id); - - if (!string.IsNullOrEmpty(fineTuneResponse.FineTunedModel)) - { - _ = activity.SetTag("ai.response.fine_tuned_model", fineTuneResponse.FineTunedModel); - } - - Logger.LogTrace("AddFineTuneResponseTags() finished"); - } - - private void AddAudioResponseTags(Activity activity, OpenAIRequest openAIRequest, string responseBody) - { - Logger.LogTrace("AddAudioResponseTags() called"); - - var audioResponse = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); - if (audioResponse is null) - { - return; - } - - RecordUsageMetrics(activity, openAIRequest, audioResponse); - - // Record the transcription text if configured - if (Configuration.IncludeCompletion && !string.IsNullOrEmpty(audioResponse.Text)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_COMPLETION, audioResponse.Text); - } - - Logger.LogTrace("AddAudioResponseTags() finished"); - } - - private void AddImageResponseTags(Activity activity, OpenAIRequest openAIRequest, string responseBody) - { - Logger.LogTrace("AddImageResponseTags() called"); - - var imageResponse = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); - if (imageResponse is null) - { - return; - } - - RecordUsageMetrics(activity, openAIRequest, imageResponse); - - _ = activity.SetTag(SemanticConvention.GEN_AI_RESPONSE_ID, imageResponse.Id); - - if (imageResponse.Data != null) - { - _ = activity.SetTag("ai.response.image.count", imageResponse.Data.Count()); - - if (Configuration.IncludeCompletion && - imageResponse.Data.Any() && - !string.IsNullOrEmpty(imageResponse.Data.FirstOrDefault()?.RevisedPrompt)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_REVISED_PROMPT, - imageResponse.Data.First().RevisedPrompt); - } - } - - Logger.LogTrace("AddImageResponseTags() finished"); - } - - private void AddEmbeddingResponseTags(Activity activity, OpenAIRequest openAIRequest, string responseBody) - { - Logger.LogTrace("AddEmbeddingResponseTags() called"); - - var embeddingResponse = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); - if (embeddingResponse is null) - { - return; - } - - RecordUsageMetrics(activity, openAIRequest, embeddingResponse); - - // Embedding response doesn't have a "completion" but we can record some metadata - _ = activity.SetTag(SemanticConvention.GEN_AI_RESPONSE_ID, embeddingResponse.Id); - - if (embeddingResponse.Data is not null) - { - _ = activity.SetTag("ai.embedding.count", embeddingResponse.Data.Count()); - - // If there's only one embedding, record the dimensions - if (embeddingResponse.Data.Count() == 1 && - embeddingResponse.Data.First().Embedding is not null) - { - _ = activity.SetTag("ai.embedding.dimensions", embeddingResponse.Data.FirstOrDefault()?.Embedding?.Count() ?? 0); - } - } - - Logger.LogTrace("AddEmbeddingResponseTags() finished"); - } - - private void AddChatCompletionResponseTags(Activity activity, OpenAIRequest openAIRequest, string responseBody) - { - Logger.LogTrace("AddChatCompletionResponseTags() called"); - - var chatResponse = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); - if (chatResponse is null) - { - return; - } - - RecordUsageMetrics(activity, openAIRequest, chatResponse); - - if (chatResponse.Choices?.FirstOrDefault()?.Message is not null) - { - if (Configuration.IncludeCompletion) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_COMPLETION, chatResponse.Choices.First().Message.Content); - } - - _ = activity.SetTag(SemanticConvention.GEN_AI_RESPONSE_FINISH_REASON, chatResponse.Choices.First().FinishReason); - } - - _ = activity.SetTag(SemanticConvention.GEN_AI_RESPONSE_ID, chatResponse.Id); - - Logger.LogTrace("AddChatCompletionResponseTags() finished"); - } - - private void AddCompletionResponseTags(Activity activity, OpenAIRequest openAIRequest, string responseBody) - { - Logger.LogTrace("AddCompletionResponseTags() called"); - - var completionResponse = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); - if (completionResponse is null) - { - return; - } - - RecordUsageMetrics(activity, openAIRequest, completionResponse); - - if (completionResponse.Choices?.FirstOrDefault() is not null) - { - if (Configuration.IncludeCompletion) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_COMPLETION, completionResponse.Choices.First().Text); - } - - _ = activity.SetTag(SemanticConvention.GEN_AI_RESPONSE_FINISH_REASON, completionResponse.Choices.First().FinishReason); - } - - _ = activity.SetTag(SemanticConvention.GEN_AI_RESPONSE_ID, completionResponse.Id); - - Logger.LogTrace("AddCompletionResponseTags() finished"); - } - - private void AddRequestTypeSpecificTags(Activity activity, OpenAIRequest openAiRequest) - { - switch (openAiRequest) - { - case OpenAIChatCompletionRequest chatRequest: - AddChatCompletionRequestTags(activity, chatRequest); - break; - case OpenAICompletionRequest completionRequest: - AddCompletionRequestTags(activity, completionRequest); - break; - case OpenAIEmbeddingRequest embeddingRequest: - AddEmbeddingRequestTags(activity, embeddingRequest); - break; - case OpenAIImageRequest imageRequest: - AddImageRequestTags(activity, imageRequest); - break; - case OpenAIAudioRequest audioRequest: - AddAudioRequestTags(activity, audioRequest); - break; - case OpenAIAudioSpeechRequest speechRequest: - AddAudioSpeechRequestTags(activity, speechRequest); - break; - case OpenAIFineTuneRequest fineTuneRequest: - AddFineTuneRequestTags(activity, fineTuneRequest); - break; - default: - throw new InvalidOperationException($"Unsupported OpenAI request type: {openAiRequest.GetType().Name}"); - } - } - - private void AddCompletionRequestTags(Activity activity, OpenAICompletionRequest completionRequest) - { - Logger.LogTrace("AddCompletionRequestTags() called"); - - // OpenLIT - _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_CONTENT_COMPLETION) - // OpenTelemetry - .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_CONTENT_COMPLETION); - - if (Configuration.IncludePrompt) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_PROMPT, completionRequest.Prompt); - } - - Logger.LogTrace("AddCompletionRequestTags() finished"); - } - - private void AddChatCompletionRequestTags(Activity activity, OpenAIChatCompletionRequest chatRequest) - { - Logger.LogTrace("AddChatCompletionRequestTags() called"); - - // OpenLIT - _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_OPERATION_TYPE_CHAT) - // OpenTelemetry - .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_OPERATION_TYPE_CHAT); - - if (Configuration.IncludePrompt) - { - // Format messages to a more readable form for the span - var formattedMessages = chatRequest.Messages - .Select(m => $"{m.Role}: {m.Content}") - .ToArray(); - - _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_PROMPT, string.Join("\n", formattedMessages)); - } - - Logger.LogTrace("AddChatCompletionRequestTags() finished"); - } - - private void AddEmbeddingRequestTags(Activity activity, OpenAIEmbeddingRequest embeddingRequest) - { - Logger.LogTrace("AddEmbeddingRequestTags() called"); - - // OpenLIT - _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_OPERATION_TYPE_EMBEDDING) - // OpenTelemetry - .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_OPERATION_TYPE_EMBEDDING); - - if (Configuration.IncludePrompt && embeddingRequest.Input is not null) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_PROMPT, embeddingRequest.Input); - } - - if (embeddingRequest.EncodingFormat is not null) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_ENCODING_FORMATS, embeddingRequest.EncodingFormat); - } - - if (embeddingRequest.Dimensions.HasValue) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_EMBEDDING_DIMENSION, embeddingRequest.Dimensions.Value); - } - - Logger.LogTrace("AddEmbeddingRequestTags() finished"); - } - - private void AddImageRequestTags(Activity activity, OpenAIImageRequest imageRequest) - { - Logger.LogTrace("AddImageRequestTags() called"); - - // OpenLIT - _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_OPERATION_TYPE_IMAGE) - // OpenTelemetry - .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_OPERATION_TYPE_IMAGE); - - if (Configuration.IncludePrompt && !string.IsNullOrEmpty(imageRequest.Prompt)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_PROMPT, imageRequest.Prompt); - } - - if (!string.IsNullOrEmpty(imageRequest.Size)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_IMAGE_SIZE, imageRequest.Size); - } - - if (!string.IsNullOrEmpty(imageRequest.Quality)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_IMAGE_QUALITY, imageRequest.Quality); - } - - if (!string.IsNullOrEmpty(imageRequest.Style)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_IMAGE_STYLE, imageRequest.Style); - } - - if (imageRequest.N.HasValue) - { - _ = activity.SetTag("ai.request.image.count", imageRequest.N.Value); - } - - Logger.LogTrace("AddImageRequestTags() finished"); - } - - private void AddAudioRequestTags(Activity activity, OpenAIAudioRequest audioRequest) - { - Logger.LogTrace("AddAudioRequestTags() called"); - - // OpenLIT - _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_OPERATION_TYPE_AUDIO) - // OpenTelemetry - .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_OPERATION_TYPE_AUDIO); - - if (!string.IsNullOrEmpty(audioRequest.ResponseFormat)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_AUDIO_RESPONSE_FORMAT, audioRequest.ResponseFormat); - } - - if (!string.IsNullOrEmpty(audioRequest.Prompt) && Configuration.IncludePrompt) - { - _ = activity.SetTag("ai.request.audio.prompt", audioRequest.Prompt); - } - - if (!string.IsNullOrEmpty(audioRequest.Language)) - { - _ = activity.SetTag("ai.request.audio.language", audioRequest.Language); - } - - Logger.LogTrace("AddAudioRequestTags() finished"); - } - - private void AddAudioSpeechRequestTags(Activity activity, OpenAIAudioSpeechRequest speechRequest) - { - Logger.LogTrace("AddAudioSpeechRequestTags() called"); - - // OpenLIT - _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_OPERATION_TYPE_AUDIO) - // OpenTelemetry - .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_OPERATION_TYPE_AUDIO); - - if (Configuration.IncludePrompt && !string.IsNullOrEmpty(speechRequest.Input)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_PROMPT, speechRequest.Input); - } - - if (!string.IsNullOrEmpty(speechRequest.Voice)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_AUDIO_VOICE, speechRequest.Voice); - } - - if (!string.IsNullOrEmpty(speechRequest.ResponseFormat)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_AUDIO_RESPONSE_FORMAT, speechRequest.ResponseFormat); - } - - if (speechRequest.Speed.HasValue) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_AUDIO_SPEED, speechRequest.Speed.Value); - } - - Logger.LogTrace("AddAudioSpeechRequestTags() finished"); - } - - private void AddFineTuneRequestTags(Activity activity, OpenAIFineTuneRequest fineTuneRequest) - { - Logger.LogTrace("AddFineTuneRequestTags() called"); - - // OpenLIT - _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_OPERATION_TYPE_FINETUNING) - // OpenTelemetry - .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_OPERATION_TYPE_FINETUNING); - - if (!string.IsNullOrEmpty(fineTuneRequest.TrainingFile)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_TRAINING_FILE, fineTuneRequest.TrainingFile); - } - - if (!string.IsNullOrEmpty(fineTuneRequest.ValidationFile)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_VALIDATION_FILE, fineTuneRequest.ValidationFile); - } - - if (fineTuneRequest.BatchSize.HasValue) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_FINETUNE_BATCH_SIZE, fineTuneRequest.BatchSize.Value); - } - - if (fineTuneRequest.LearningRateMultiplier.HasValue) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_FINETUNE_MODEL_LRM, - fineTuneRequest.LearningRateMultiplier.Value); - } - - if (fineTuneRequest.Epochs.HasValue) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_FINETUNE_MODEL_EPOCHS, - fineTuneRequest.Epochs.Value); - } - - if (!string.IsNullOrEmpty(fineTuneRequest.Suffix)) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_FINETUNE_MODEL_SUFFIX, - fineTuneRequest.Suffix); - } - - Logger.LogTrace("AddFineTuneRequestTags() finished"); - } - - private void AddCommonRequestTags(Activity activity, OpenAIRequest openAiRequest) - { - Logger.LogTrace("AddCommonRequestTags() called"); - - if (openAiRequest.Temperature.HasValue) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_TEMPERATURE, - openAiRequest.Temperature.Value); - } - - if (openAiRequest.MaxTokens.HasValue) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_MAX_TOKENS, - openAiRequest.MaxTokens.Value); - } - - if (openAiRequest.TopP.HasValue) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_TOP_P, - openAiRequest.TopP.Value); - } - - if (openAiRequest.PresencePenalty.HasValue) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_PRESENCE_PENALTY, - openAiRequest.PresencePenalty.Value); - } - - if (openAiRequest.FrequencyPenalty.HasValue) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_FREQUENCY_PENALTY, - openAiRequest.FrequencyPenalty.Value); - } - - if (openAiRequest.Stream.HasValue) - { - _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_IS_STREAM, - openAiRequest.Stream.Value); - } - - Logger.LogTrace("AddCommonRequestTags() finished"); - } - - private void RecordUsageMetrics(Activity activity, OpenAIRequest request, OpenAIResponse response) - { - Logger.LogTrace("RecordUsageMetrics() called"); - - var usage = response.Usage; - if (usage is null) - { - return; - } - - Debug.Assert(_tokenUsageMetric is not null, "Token usage histogram is not initialized"); - _tokenUsageMetric.Record(response.Usage?.PromptTokens ?? 0, - [ - new(SemanticConvention.GEN_AI_OPERATION_NAME, GetOperationName(request)), - new(SemanticConvention.GEN_AI_SYSTEM, OpenAISystem), - new(SemanticConvention.GEN_AI_TOKEN_TYPE, SemanticConvention.GEN_AI_TOKEN_TYPE_INPUT), - new(SemanticConvention.GEN_AI_REQUEST_MODEL, request.Model), - new(SemanticConvention.GEN_AI_RESPONSE_MODEL, response.Model) - ]); - _tokenUsageMetric.Record(response.Usage?.CompletionTokens ?? 0, - [ - new(SemanticConvention.GEN_AI_OPERATION_NAME, GetOperationName(request)), - new(SemanticConvention.GEN_AI_SYSTEM, OpenAISystem), - new(SemanticConvention.GEN_AI_TOKEN_TYPE, SemanticConvention.GEN_AI_TOKEN_TYPE_OUTPUT), - new(SemanticConvention.GEN_AI_REQUEST_MODEL, request.Model), - new(SemanticConvention.GEN_AI_RESPONSE_MODEL, response.Model) - ]); - - _ = activity.SetTag(SemanticConvention.GEN_AI_USAGE_INPUT_TOKENS, usage.PromptTokens) - .SetTag(SemanticConvention.GEN_AI_USAGE_OUTPUT_TOKENS, usage.CompletionTokens) - .SetTag(SemanticConvention.GEN_AI_USAGE_TOTAL_TOKENS, usage.TotalTokens); - - var reportModelUsageInformation = new OpenAITelemetryPluginReportModelUsageInformation - { - Model = response.Model, - PromptTokens = usage.PromptTokens, - CompletionTokens = usage.CompletionTokens - }; - var usagePerModel = _modelUsage.GetOrAdd(response.Model, model => []); - usagePerModel.Add(reportModelUsageInformation); - - if (!Configuration.IncludeCosts || Configuration.Prices is null) - { - Logger.LogDebug("Cost tracking is disabled or prices data is not available"); - return; - } - - if (string.IsNullOrEmpty(response.Model)) - { - Logger.LogDebug("Response model is empty or null"); - return; - } - - var (inputCost, outputCost) = Configuration.Prices.CalculateCost(response.Model, usage.PromptTokens, usage.CompletionTokens); - - if (inputCost > 0) - { - var totalCost = inputCost + outputCost; - _ = activity.SetTag(SemanticConvention.GEN_AI_USAGE_COST, totalCost); - - Debug.Assert(_requestCostMetric is not null, "Cost histogram is not initialized"); - Debug.Assert(_totalCostMetric is not null, "Total cost counter is not initialized"); - - _requestCostMetric.Record(totalCost, - [ - new(SemanticConvention.GEN_AI_OPERATION_NAME, GetOperationName(request)), - new(SemanticConvention.GEN_AI_SYSTEM, OpenAISystem), - new(SemanticConvention.GEN_AI_REQUEST_MODEL, request.Model), - new(SemanticConvention.GEN_AI_RESPONSE_MODEL, response.Model) - ]); - _totalCostMetric.Add(totalCost, - [ - new(SemanticConvention.GEN_AI_OPERATION_NAME, GetOperationName(request)), - new(SemanticConvention.GEN_AI_SYSTEM, OpenAISystem), - new(SemanticConvention.GEN_AI_REQUEST_MODEL, request.Model), - new(SemanticConvention.GEN_AI_RESPONSE_MODEL, response.Model) - ]); - reportModelUsageInformation.Cost = totalCost; - } - else - { - Logger.LogDebug("Input cost is zero, skipping cost metrics recording"); - } - - Logger.LogTrace("RecordUsageMetrics() finished"); - } - - private bool TryGetOpenAIRequest(string content, out OpenAIRequest? request) - { - Logger.LogTrace("TryGetOpenAIRequest() called"); - - request = null; - - if (string.IsNullOrEmpty(content)) - { - Logger.LogDebug("Request content is empty or null"); - return false; - } - - try - { - Logger.LogDebug("Checking if the request is an OpenAI request..."); - - var rawRequest = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); - - // Check for completion request (has "prompt", but not specific to image) - if (rawRequest.TryGetProperty("prompt", out _) && - !rawRequest.TryGetProperty("size", out _) && - !rawRequest.TryGetProperty("n", out _)) - { - Logger.LogDebug("Request is a completion request"); - request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); - return true; - } - - // Chat completion request - if (rawRequest.TryGetProperty("messages", out _)) - { - Logger.LogDebug("Request is a chat completion request"); - request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); - return true; - } - - // Embedding request - if (rawRequest.TryGetProperty("input", out _) && - rawRequest.TryGetProperty("model", out _) && - !rawRequest.TryGetProperty("voice", out _)) - { - Logger.LogDebug("Request is an embedding request"); - request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); - return true; - } - - // Image generation request - if (rawRequest.TryGetProperty("prompt", out _) && - (rawRequest.TryGetProperty("size", out _) || rawRequest.TryGetProperty("n", out _))) - { - Logger.LogDebug("Request is an image generation request"); - request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); - return true; - } - - // Audio transcription request - if (rawRequest.TryGetProperty("file", out _)) - { - Logger.LogDebug("Request is an audio transcription request"); - request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); - return true; - } - - // Audio speech synthesis request - if (rawRequest.TryGetProperty("input", out _) && rawRequest.TryGetProperty("voice", out _)) - { - Logger.LogDebug("Request is an audio speech synthesis request"); - request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); - return true; - } - - // Fine-tuning request - if (rawRequest.TryGetProperty("training_file", out _)) - { - Logger.LogDebug("Request is a fine-tuning request"); - request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); - return true; - } - - Logger.LogDebug("Request is not an OpenAI request."); - return false; - } - catch (JsonException ex) - { - Logger.LogDebug(ex, "Failed to deserialize OpenAI request."); - return false; - } - } - - private static string GetOperationName(OpenAIRequest request) - { - if (request == null) - { - return "unknown"; - } - - return request switch - { - OpenAIChatCompletionRequest => "chat.completions", - OpenAICompletionRequest => "completions", - OpenAIEmbeddingRequest => "embeddings", - OpenAIImageRequest => "images.generations", - OpenAIAudioRequest => "audio.transcriptions", - OpenAIAudioSpeechRequest => "audio.speech", - OpenAIFineTuneRequest => "fine_tuning.jobs", - _ => "unknown" - }; - } - - private bool IsStreamingResponse(Response response) - { - Logger.LogTrace("{Method} called", nameof(IsStreamingResponse)); - var contentType = response.Headers.FirstOrDefault(h => h.Name.Equals("content-type", StringComparison.OrdinalIgnoreCase))?.Value; - if (string.IsNullOrEmpty(contentType)) - { - Logger.LogDebug("No content-type header found"); - return false; - } - - var isStreamingResponse = contentType.Contains("text/event-stream", StringComparison.OrdinalIgnoreCase); - Logger.LogDebug("IsStreamingResponse: {IsStreamingResponse}", isStreamingResponse); - - Logger.LogTrace("{Method} finished", nameof(IsStreamingResponse)); - return isStreamingResponse; - } - - private string GetBodyFromStreamingResponse(Response response) - { - Logger.LogTrace("{Method} called", nameof(GetBodyFromStreamingResponse)); - - // default to the whole body - var bodyString = response.BodyString; - - var chunks = bodyString.Split("\n\n", StringSplitOptions.RemoveEmptyEntries); - if (chunks.Length == 0) - { - Logger.LogDebug("No chunks found in the response body"); - return bodyString; - } - - // check if the last chunk is `data: [DONE]` - var lastChunk = chunks.Last().Trim(); - if (lastChunk.Equals("data: [DONE]", StringComparison.OrdinalIgnoreCase)) - { - // get next to last chunk - var chunk = chunks.Length > 1 ? chunks[^2].Trim() : string.Empty; - if (chunk.StartsWith("data: ", StringComparison.OrdinalIgnoreCase)) - { - // remove the "data: " prefix - bodyString = chunk["data: ".Length..].Trim(); - Logger.LogDebug("Last chunk starts with 'data: ', using the last chunk as the body: {BodyString}", bodyString); - } - else - { - Logger.LogDebug("Last chunk does not start with 'data: ', using the whole body"); - } - } - else - { - Logger.LogDebug("Last chunk is not `data: [DONE]`, using the whole body"); - } - - Logger.LogTrace("{Method} finished", nameof(GetBodyFromStreamingResponse)); - return bodyString; - } - - public void Dispose() - { - _loader?.Dispose(); - _activitySource?.Dispose(); - _tracerProvider?.Dispose(); - _meterProvider?.Dispose(); - } -} +//public sealed class OpenAITelemetryPlugin( +// HttpClient httpClient, +// ILogger logger, +// ISet urlsToWatch, +// IProxyConfiguration proxyConfiguration, +// IConfigurationSection pluginConfigurationSection, +// IProxyStorage proxyStorage) : +// BaseReportingPlugin( +// httpClient, +// logger, +// urlsToWatch, +// proxyConfiguration, +// pluginConfigurationSection, +// proxyStorage), IDisposable +//{ +// private const string ActivitySourceName = "DevProxy.OpenAI"; +// private const string OpenAISystem = "openai"; + +// private static readonly Meter _meter = new(ActivitySourceName); +// private static Histogram? _requestCostMetric; +// private static Counter? _totalCostMetric; +// private static Histogram? _tokenUsageMetric; + +// private readonly ActivitySource _activitySource = new(ActivitySourceName); +// private LanguageModelPricesLoader? _loader; +// private MeterProvider? _meterProvider; +// private TracerProvider? _tracerProvider; +// private readonly ConcurrentDictionary> _modelUsage = []; + +// public override string Name => nameof(OpenAITelemetryPlugin); + +// public override async Task InitializeAsync(InitArgs e, CancellationToken cancellationToken) +// { +// ArgumentNullException.ThrowIfNull(e); + +// await base.InitializeAsync(e, cancellationToken); + +// if (Configuration.IncludeCosts) +// { +// Configuration.PricesFile = ProxyUtils.GetFullPath(Configuration.PricesFile, ProxyConfiguration.ConfigFile); +// _loader = ActivatorUtilities.CreateInstance(e.ServiceProvider, Configuration); +// await _loader.InitFileWatcherAsync(cancellationToken); +// } + +// InitializeOpenTelemetryExporter(); +// } + +// public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) +// { +// Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + +// ArgumentNullException.ThrowIfNull(e); + +// if (!e.HasRequestUrlMatch(UrlsToWatch)) +// { +// Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); +// return Task.CompletedTask; +// } + +// var request = e.Session.HttpClient.Request; +// if (request.Method is null || +// !request.Method.Equals("POST", StringComparison.OrdinalIgnoreCase) || +// !request.HasBody) +// { +// Logger.LogRequest("Request is not a POST request with a body", MessageType.Skipped, new(e.Session)); +// return Task.CompletedTask; +// } + +// if (!TryGetOpenAIRequest(request.BodyString, out var openAiRequest) || openAiRequest is null) +// { +// Logger.LogRequest("Skipping non-OpenAI request", MessageType.Skipped, new(e.Session)); +// return Task.CompletedTask; +// } + +// // store for use in response +// e.SessionData["OpenAIRequest"] = openAiRequest; + +// var activity = _activitySource.StartActivity( +// $"openai.{GetOperationName(openAiRequest)}", +// ActivityKind.Client); + +// if (activity is null) +// { +// Logger.LogWarning("Failed to start OpenTelemetry activity for OpenAI request"); +// return Task.CompletedTask; +// } + +// // add generic request tags +// _ = activity.SetTag("http.method", request.Method) +// .SetTag("http.url", request.RequestUri.ToString()) +// .SetTag("http.scheme", request.RequestUri.Scheme) +// .SetTag("http.host", request.RequestUri.Host) +// .SetTag("http.target", request.RequestUri.PathAndQuery) +// .SetTag(SemanticConvention.GEN_AI_SYSTEM, OpenAISystem) +// .SetTag(SemanticConvention.GEN_AI_ENVIRONMENT, Configuration.Environment) +// .SetTag(SemanticConvention.GEN_AI_APPLICATION_NAME, Configuration.Application); + +// AddCommonRequestTags(activity, openAiRequest); +// AddRequestTypeSpecificTags(activity, openAiRequest); + +// // store for use in response +// e.SessionData["OpenAIActivity"] = activity; + +// Logger.LogTrace("OnRequestAsync() finished"); +// return Task.CompletedTask; +// } + +// public override Task AfterResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) +// { +// Logger.LogTrace("{Method} called", nameof(AfterResponseAsync)); + +// ArgumentNullException.ThrowIfNull(e); + +// if (!e.SessionData.TryGetValue("OpenAIActivity", out var activityObj) || +// activityObj is not Activity activity) +// { +// return Task.CompletedTask; +// } + +// try +// { +// var response = e.Session.HttpClient.Response; + +// _ = activity.SetTag("http.status_code", response.StatusCode); + +//#pragma warning disable IDE0010 +// switch (response.StatusCode) +//#pragma warning restore IDE0010 +// { +// case int code when code is >= 200 and < 300: +// ProcessSuccessResponse(activity, e); +// break; +// case int code when code >= 400: +// ProcessErrorResponse(activity, e); +// break; +// } +// } +// finally +// { +// activity.Stop(); + +// // Clean up session data +// _ = e.SessionData.Remove("OpenAIActivity"); +// _ = e.SessionData.Remove("OpenAIRequest"); + +// Logger.LogRequest("OpenTelemetry information emitted", MessageType.Processed, new(e.Session)); +// } + +// Logger.LogTrace("Left {Name}", nameof(AfterResponseAsync)); +// return Task.CompletedTask; +// } + +// public override Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) +// { +// Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); + +// var report = new OpenAITelemetryPluginReport +// { +// Application = Configuration.Application, +// Environment = Configuration.Environment, +// Currency = Configuration.Currency, +// IncludeCosts = Configuration.IncludeCosts, +// ModelUsage = _modelUsage.ToDictionary() +// }; + +// StoreReport(report); +// _modelUsage.Clear(); + +// Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); +// return Task.CompletedTask; +// } + +// private void InitializeOpenTelemetryExporter() +// { +// Logger.LogTrace("InitializeOpenTelemetryExporter() called"); + +// try +// { +// void configureOtlpExporter(OtlpExporterOptions options) +// { +// // We use protobuf to allow intercepting Dev Proxy's own LLM traffic +// options.Protocol = OtlpExportProtocol.HttpProtobuf; +// options.Endpoint = new Uri(Configuration.ExporterEndpoint + "/v1/traces"); +// } + +// var resourceBuilder = ResourceBuilder +// .CreateDefault() +// .AddService(serviceName: "DevProxy.OpenAI", serviceVersion: ProxyUtils.ProductVersion); + +// _tracerProvider = Sdk.CreateTracerProviderBuilder() +// .SetResourceBuilder(resourceBuilder) +// .AddSource(ActivitySourceName) +// .AddOtlpExporter(configureOtlpExporter) +// .Build(); + +// _meterProvider = Sdk.CreateMeterProviderBuilder() +// .SetResourceBuilder(resourceBuilder) +// .AddMeter(ActivitySourceName) +// .AddView(SemanticConvention.GEN_AI_METRIC_CLIENT_TOKEN_USAGE, new ExplicitBucketHistogramConfiguration +// { +// Boundaries = [1, 4, 16, 64, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304, 16777216, 67108864] +// }) +// .AddView(SemanticConvention.GEN_AI_USAGE_COST, new ExplicitBucketHistogramConfiguration +// { +// Boundaries = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 5, 10, 50, 100] +// }) +// .AddView(SemanticConvention.GEN_AI_USAGE_TOTAL_COST, new MetricStreamConfiguration()) +// .AddOtlpExporter(configureOtlpExporter) +// .Build(); + +// _tokenUsageMetric = _meter.CreateHistogram( +// SemanticConvention.GEN_AI_METRIC_CLIENT_TOKEN_USAGE, +// "tokens", +// "Number of tokens processed"); +// _requestCostMetric = _meter.CreateHistogram( +// SemanticConvention.GEN_AI_USAGE_COST, +// "cost", +// $"Estimated cost per request in {Configuration.Currency}"); +// _totalCostMetric = _meter.CreateCounter( +// SemanticConvention.GEN_AI_USAGE_TOTAL_COST, +// "cost", +// $"Total estimated cost for the session in {Configuration.Currency}"); +// } +// catch (Exception ex) +// { +// Logger.LogError(ex, "Failed to initialize OpenTelemetry exporter"); +// } + +// Logger.LogTrace("InitializeOpenTelemetryExporter() finished"); +// } + +// private void ProcessErrorResponse(Activity activity, ProxyResponseArgs e) +// { +// Logger.LogTrace("ProcessErrorResponse() called"); + +// var response = e.Session.HttpClient.Response; + +// _ = activity.SetTag("error", true) +// .SetTag("error.type", "http") +// .SetTag("error.message", $"HTTP {response.StatusCode}"); + +// if (response.HasBody) +// { +// try +// { +// var errorObj = JsonSerializer.Deserialize(response.BodyString); +// if (errorObj.TryGetProperty("error", out var error)) +// { +// if (error.TryGetProperty("message", out var message)) +// { +// _ = activity.SetTag("error.details", message.GetString()); +// } +// } +// } +// catch (JsonException) +// { +// // Ignore JSON parsing errors in error responses +// } +// } + +// Logger.LogTrace("ProcessErrorResponse() finished"); +// } + +// private void ProcessSuccessResponse(Activity activity, ProxyResponseArgs e) +// { +// Logger.LogTrace("ProcessSuccessResponse() called"); + +// var response = e.Session.HttpClient.Response; + +// if (!response.HasBody || string.IsNullOrEmpty(response.BodyString)) +// { +// Logger.LogDebug("Response body is empty or null"); +// return; +// } + +// if (!e.SessionData.TryGetValue("OpenAIRequest", out var requestObj) || +// requestObj is not OpenAIRequest openAiRequest) +// { +// Logger.LogDebug("OpenAI request not found in session data"); +// return; +// } + +// var bodyString = response.BodyString; +// if (IsStreamingResponse(response)) +// { +// bodyString = GetBodyFromStreamingResponse(response); +// } + +// AddResponseTypeSpecificTags(activity, openAiRequest, bodyString); + +// Logger.LogTrace("ProcessSuccessResponse() finished"); +// } + +// private void AddResponseTypeSpecificTags(Activity activity, OpenAIRequest openAiRequest, string responseBody) +// { +// Logger.LogTrace("AddResponseTypeSpecificTags() called"); + +// try +// { +// switch (openAiRequest) +// { +// case OpenAIChatCompletionRequest: +// AddChatCompletionResponseTags(activity, openAiRequest, responseBody); +// break; +// case OpenAICompletionRequest: +// AddCompletionResponseTags(activity, openAiRequest, responseBody); +// break; +// case OpenAIEmbeddingRequest: +// AddEmbeddingResponseTags(activity, openAiRequest, responseBody); +// break; +// case OpenAIImageRequest: +// AddImageResponseTags(activity, openAiRequest, responseBody); +// break; +// case OpenAIAudioRequest: +// AddAudioResponseTags(activity, openAiRequest, responseBody); +// break; +// case OpenAIFineTuneRequest: +// AddFineTuneResponseTags(activity, openAiRequest, responseBody); +// break; +// default: +// throw new InvalidOperationException($"Unsupported OpenAI request type: {openAiRequest.GetType().Name}"); +// } +// } +// catch (JsonException ex) +// { +// Logger.LogError(ex, "Failed to deserialize OpenAI response"); +// _ = activity.SetTag("error", ex.Message); +// } +// } + +// private void AddFineTuneResponseTags(Activity activity, OpenAIRequest openAIRequest, string responseBody) +// { +// Logger.LogTrace("AddFineTuneResponseTags() called"); + +// var fineTuneResponse = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); +// if (fineTuneResponse is null) +// { +// return; +// } + +// RecordUsageMetrics(activity, openAIRequest, fineTuneResponse); + +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_FINETUNE_STATUS, fineTuneResponse.Status) +// .SetTag(SemanticConvention.GEN_AI_RESPONSE_ID, fineTuneResponse.Id); + +// if (!string.IsNullOrEmpty(fineTuneResponse.FineTunedModel)) +// { +// _ = activity.SetTag("ai.response.fine_tuned_model", fineTuneResponse.FineTunedModel); +// } + +// Logger.LogTrace("AddFineTuneResponseTags() finished"); +// } + +// private void AddAudioResponseTags(Activity activity, OpenAIRequest openAIRequest, string responseBody) +// { +// Logger.LogTrace("AddAudioResponseTags() called"); + +// var audioResponse = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); +// if (audioResponse is null) +// { +// return; +// } + +// RecordUsageMetrics(activity, openAIRequest, audioResponse); + +// // Record the transcription text if configured +// if (Configuration.IncludeCompletion && !string.IsNullOrEmpty(audioResponse.Text)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_COMPLETION, audioResponse.Text); +// } + +// Logger.LogTrace("AddAudioResponseTags() finished"); +// } + +// private void AddImageResponseTags(Activity activity, OpenAIRequest openAIRequest, string responseBody) +// { +// Logger.LogTrace("AddImageResponseTags() called"); + +// var imageResponse = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); +// if (imageResponse is null) +// { +// return; +// } + +// RecordUsageMetrics(activity, openAIRequest, imageResponse); + +// _ = activity.SetTag(SemanticConvention.GEN_AI_RESPONSE_ID, imageResponse.Id); + +// if (imageResponse.Data != null) +// { +// _ = activity.SetTag("ai.response.image.count", imageResponse.Data.Count()); + +// if (Configuration.IncludeCompletion && +// imageResponse.Data.Any() && +// !string.IsNullOrEmpty(imageResponse.Data.FirstOrDefault()?.RevisedPrompt)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_REVISED_PROMPT, +// imageResponse.Data.First().RevisedPrompt); +// } +// } + +// Logger.LogTrace("AddImageResponseTags() finished"); +// } + +// private void AddEmbeddingResponseTags(Activity activity, OpenAIRequest openAIRequest, string responseBody) +// { +// Logger.LogTrace("AddEmbeddingResponseTags() called"); + +// var embeddingResponse = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); +// if (embeddingResponse is null) +// { +// return; +// } + +// RecordUsageMetrics(activity, openAIRequest, embeddingResponse); + +// // Embedding response doesn't have a "completion" but we can record some metadata +// _ = activity.SetTag(SemanticConvention.GEN_AI_RESPONSE_ID, embeddingResponse.Id); + +// if (embeddingResponse.Data is not null) +// { +// _ = activity.SetTag("ai.embedding.count", embeddingResponse.Data.Count()); + +// // If there's only one embedding, record the dimensions +// if (embeddingResponse.Data.Count() == 1 && +// embeddingResponse.Data.First().Embedding is not null) +// { +// _ = activity.SetTag("ai.embedding.dimensions", embeddingResponse.Data.FirstOrDefault()?.Embedding?.Count() ?? 0); +// } +// } + +// Logger.LogTrace("AddEmbeddingResponseTags() finished"); +// } + +// private void AddChatCompletionResponseTags(Activity activity, OpenAIRequest openAIRequest, string responseBody) +// { +// Logger.LogTrace("AddChatCompletionResponseTags() called"); + +// var chatResponse = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); +// if (chatResponse is null) +// { +// return; +// } + +// RecordUsageMetrics(activity, openAIRequest, chatResponse); + +// if (chatResponse.Choices?.FirstOrDefault()?.Message is not null) +// { +// if (Configuration.IncludeCompletion) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_COMPLETION, chatResponse.Choices.First().Message.Content); +// } + +// _ = activity.SetTag(SemanticConvention.GEN_AI_RESPONSE_FINISH_REASON, chatResponse.Choices.First().FinishReason); +// } + +// _ = activity.SetTag(SemanticConvention.GEN_AI_RESPONSE_ID, chatResponse.Id); + +// Logger.LogTrace("AddChatCompletionResponseTags() finished"); +// } + +// private void AddCompletionResponseTags(Activity activity, OpenAIRequest openAIRequest, string responseBody) +// { +// Logger.LogTrace("AddCompletionResponseTags() called"); + +// var completionResponse = JsonSerializer.Deserialize(responseBody, ProxyUtils.JsonSerializerOptions); +// if (completionResponse is null) +// { +// return; +// } + +// RecordUsageMetrics(activity, openAIRequest, completionResponse); + +// if (completionResponse.Choices?.FirstOrDefault() is not null) +// { +// if (Configuration.IncludeCompletion) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_COMPLETION, completionResponse.Choices.First().Text); +// } + +// _ = activity.SetTag(SemanticConvention.GEN_AI_RESPONSE_FINISH_REASON, completionResponse.Choices.First().FinishReason); +// } + +// _ = activity.SetTag(SemanticConvention.GEN_AI_RESPONSE_ID, completionResponse.Id); + +// Logger.LogTrace("AddCompletionResponseTags() finished"); +// } + +// private void AddRequestTypeSpecificTags(Activity activity, OpenAIRequest openAiRequest) +// { +// switch (openAiRequest) +// { +// case OpenAIChatCompletionRequest chatRequest: +// AddChatCompletionRequestTags(activity, chatRequest); +// break; +// case OpenAICompletionRequest completionRequest: +// AddCompletionRequestTags(activity, completionRequest); +// break; +// case OpenAIEmbeddingRequest embeddingRequest: +// AddEmbeddingRequestTags(activity, embeddingRequest); +// break; +// case OpenAIImageRequest imageRequest: +// AddImageRequestTags(activity, imageRequest); +// break; +// case OpenAIAudioRequest audioRequest: +// AddAudioRequestTags(activity, audioRequest); +// break; +// case OpenAIAudioSpeechRequest speechRequest: +// AddAudioSpeechRequestTags(activity, speechRequest); +// break; +// case OpenAIFineTuneRequest fineTuneRequest: +// AddFineTuneRequestTags(activity, fineTuneRequest); +// break; +// default: +// throw new InvalidOperationException($"Unsupported OpenAI request type: {openAiRequest.GetType().Name}"); +// } +// } + +// private void AddCompletionRequestTags(Activity activity, OpenAICompletionRequest completionRequest) +// { +// Logger.LogTrace("AddCompletionRequestTags() called"); + +// // OpenLIT +// _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_CONTENT_COMPLETION) +// // OpenTelemetry +// .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_CONTENT_COMPLETION); + +// if (Configuration.IncludePrompt) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_PROMPT, completionRequest.Prompt); +// } + +// Logger.LogTrace("AddCompletionRequestTags() finished"); +// } + +// private void AddChatCompletionRequestTags(Activity activity, OpenAIChatCompletionRequest chatRequest) +// { +// Logger.LogTrace("AddChatCompletionRequestTags() called"); + +// // OpenLIT +// _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_OPERATION_TYPE_CHAT) +// // OpenTelemetry +// .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_OPERATION_TYPE_CHAT); + +// if (Configuration.IncludePrompt) +// { +// // Format messages to a more readable form for the span +// var formattedMessages = chatRequest.Messages +// .Select(m => $"{m.Role}: {m.Content}") +// .ToArray(); + +// _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_PROMPT, string.Join("\n", formattedMessages)); +// } + +// Logger.LogTrace("AddChatCompletionRequestTags() finished"); +// } + +// private void AddEmbeddingRequestTags(Activity activity, OpenAIEmbeddingRequest embeddingRequest) +// { +// Logger.LogTrace("AddEmbeddingRequestTags() called"); + +// // OpenLIT +// _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_OPERATION_TYPE_EMBEDDING) +// // OpenTelemetry +// .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_OPERATION_TYPE_EMBEDDING); + +// if (Configuration.IncludePrompt && embeddingRequest.Input is not null) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_PROMPT, embeddingRequest.Input); +// } + +// if (embeddingRequest.EncodingFormat is not null) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_ENCODING_FORMATS, embeddingRequest.EncodingFormat); +// } + +// if (embeddingRequest.Dimensions.HasValue) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_EMBEDDING_DIMENSION, embeddingRequest.Dimensions.Value); +// } + +// Logger.LogTrace("AddEmbeddingRequestTags() finished"); +// } + +// private void AddImageRequestTags(Activity activity, OpenAIImageRequest imageRequest) +// { +// Logger.LogTrace("AddImageRequestTags() called"); + +// // OpenLIT +// _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_OPERATION_TYPE_IMAGE) +// // OpenTelemetry +// .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_OPERATION_TYPE_IMAGE); + +// if (Configuration.IncludePrompt && !string.IsNullOrEmpty(imageRequest.Prompt)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_PROMPT, imageRequest.Prompt); +// } + +// if (!string.IsNullOrEmpty(imageRequest.Size)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_IMAGE_SIZE, imageRequest.Size); +// } + +// if (!string.IsNullOrEmpty(imageRequest.Quality)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_IMAGE_QUALITY, imageRequest.Quality); +// } + +// if (!string.IsNullOrEmpty(imageRequest.Style)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_IMAGE_STYLE, imageRequest.Style); +// } + +// if (imageRequest.N.HasValue) +// { +// _ = activity.SetTag("ai.request.image.count", imageRequest.N.Value); +// } + +// Logger.LogTrace("AddImageRequestTags() finished"); +// } + +// private void AddAudioRequestTags(Activity activity, OpenAIAudioRequest audioRequest) +// { +// Logger.LogTrace("AddAudioRequestTags() called"); + +// // OpenLIT +// _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_OPERATION_TYPE_AUDIO) +// // OpenTelemetry +// .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_OPERATION_TYPE_AUDIO); + +// if (!string.IsNullOrEmpty(audioRequest.ResponseFormat)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_AUDIO_RESPONSE_FORMAT, audioRequest.ResponseFormat); +// } + +// if (!string.IsNullOrEmpty(audioRequest.Prompt) && Configuration.IncludePrompt) +// { +// _ = activity.SetTag("ai.request.audio.prompt", audioRequest.Prompt); +// } + +// if (!string.IsNullOrEmpty(audioRequest.Language)) +// { +// _ = activity.SetTag("ai.request.audio.language", audioRequest.Language); +// } + +// Logger.LogTrace("AddAudioRequestTags() finished"); +// } + +// private void AddAudioSpeechRequestTags(Activity activity, OpenAIAudioSpeechRequest speechRequest) +// { +// Logger.LogTrace("AddAudioSpeechRequestTags() called"); + +// // OpenLIT +// _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_OPERATION_TYPE_AUDIO) +// // OpenTelemetry +// .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_OPERATION_TYPE_AUDIO); + +// if (Configuration.IncludePrompt && !string.IsNullOrEmpty(speechRequest.Input)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_CONTENT_PROMPT, speechRequest.Input); +// } + +// if (!string.IsNullOrEmpty(speechRequest.Voice)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_AUDIO_VOICE, speechRequest.Voice); +// } + +// if (!string.IsNullOrEmpty(speechRequest.ResponseFormat)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_AUDIO_RESPONSE_FORMAT, speechRequest.ResponseFormat); +// } + +// if (speechRequest.Speed.HasValue) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_AUDIO_SPEED, speechRequest.Speed.Value); +// } + +// Logger.LogTrace("AddAudioSpeechRequestTags() finished"); +// } + +// private void AddFineTuneRequestTags(Activity activity, OpenAIFineTuneRequest fineTuneRequest) +// { +// Logger.LogTrace("AddFineTuneRequestTags() called"); + +// // OpenLIT +// _ = activity.SetTag(SemanticConvention.GEN_AI_OPERATION, SemanticConvention.GEN_AI_OPERATION_TYPE_FINETUNING) +// // OpenTelemetry +// .SetTag(SemanticConvention.GEN_AI_OPERATION_NAME, SemanticConvention.GEN_AI_OPERATION_TYPE_FINETUNING); + +// if (!string.IsNullOrEmpty(fineTuneRequest.TrainingFile)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_TRAINING_FILE, fineTuneRequest.TrainingFile); +// } + +// if (!string.IsNullOrEmpty(fineTuneRequest.ValidationFile)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_VALIDATION_FILE, fineTuneRequest.ValidationFile); +// } + +// if (fineTuneRequest.BatchSize.HasValue) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_FINETUNE_BATCH_SIZE, fineTuneRequest.BatchSize.Value); +// } + +// if (fineTuneRequest.LearningRateMultiplier.HasValue) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_FINETUNE_MODEL_LRM, +// fineTuneRequest.LearningRateMultiplier.Value); +// } + +// if (fineTuneRequest.Epochs.HasValue) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_FINETUNE_MODEL_EPOCHS, +// fineTuneRequest.Epochs.Value); +// } + +// if (!string.IsNullOrEmpty(fineTuneRequest.Suffix)) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_FINETUNE_MODEL_SUFFIX, +// fineTuneRequest.Suffix); +// } + +// Logger.LogTrace("AddFineTuneRequestTags() finished"); +// } + +// private void AddCommonRequestTags(Activity activity, OpenAIRequest openAiRequest) +// { +// Logger.LogTrace("AddCommonRequestTags() called"); + +// if (openAiRequest.Temperature.HasValue) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_TEMPERATURE, +// openAiRequest.Temperature.Value); +// } + +// if (openAiRequest.MaxTokens.HasValue) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_MAX_TOKENS, +// openAiRequest.MaxTokens.Value); +// } + +// if (openAiRequest.TopP.HasValue) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_TOP_P, +// openAiRequest.TopP.Value); +// } + +// if (openAiRequest.PresencePenalty.HasValue) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_PRESENCE_PENALTY, +// openAiRequest.PresencePenalty.Value); +// } + +// if (openAiRequest.FrequencyPenalty.HasValue) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_FREQUENCY_PENALTY, +// openAiRequest.FrequencyPenalty.Value); +// } + +// if (openAiRequest.Stream.HasValue) +// { +// _ = activity.SetTag(SemanticConvention.GEN_AI_REQUEST_IS_STREAM, +// openAiRequest.Stream.Value); +// } + +// Logger.LogTrace("AddCommonRequestTags() finished"); +// } + +// private void RecordUsageMetrics(Activity activity, OpenAIRequest request, OpenAIResponse response) +// { +// Logger.LogTrace("RecordUsageMetrics() called"); + +// var usage = response.Usage; +// if (usage is null) +// { +// return; +// } + +// Debug.Assert(_tokenUsageMetric is not null, "Token usage histogram is not initialized"); +// _tokenUsageMetric.Record(response.Usage?.PromptTokens ?? 0, +// [ +// new(SemanticConvention.GEN_AI_OPERATION_NAME, GetOperationName(request)), +// new(SemanticConvention.GEN_AI_SYSTEM, OpenAISystem), +// new(SemanticConvention.GEN_AI_TOKEN_TYPE, SemanticConvention.GEN_AI_TOKEN_TYPE_INPUT), +// new(SemanticConvention.GEN_AI_REQUEST_MODEL, request.Model), +// new(SemanticConvention.GEN_AI_RESPONSE_MODEL, response.Model) +// ]); +// _tokenUsageMetric.Record(response.Usage?.CompletionTokens ?? 0, +// [ +// new(SemanticConvention.GEN_AI_OPERATION_NAME, GetOperationName(request)), +// new(SemanticConvention.GEN_AI_SYSTEM, OpenAISystem), +// new(SemanticConvention.GEN_AI_TOKEN_TYPE, SemanticConvention.GEN_AI_TOKEN_TYPE_OUTPUT), +// new(SemanticConvention.GEN_AI_REQUEST_MODEL, request.Model), +// new(SemanticConvention.GEN_AI_RESPONSE_MODEL, response.Model) +// ]); + +// _ = activity.SetTag(SemanticConvention.GEN_AI_USAGE_INPUT_TOKENS, usage.PromptTokens) +// .SetTag(SemanticConvention.GEN_AI_USAGE_OUTPUT_TOKENS, usage.CompletionTokens) +// .SetTag(SemanticConvention.GEN_AI_USAGE_TOTAL_TOKENS, usage.TotalTokens); + +// var reportModelUsageInformation = new OpenAITelemetryPluginReportModelUsageInformation +// { +// Model = response.Model, +// PromptTokens = usage.PromptTokens, +// CompletionTokens = usage.CompletionTokens +// }; +// var usagePerModel = _modelUsage.GetOrAdd(response.Model, model => []); +// usagePerModel.Add(reportModelUsageInformation); + +// if (!Configuration.IncludeCosts || Configuration.Prices is null) +// { +// Logger.LogDebug("Cost tracking is disabled or prices data is not available"); +// return; +// } + +// if (string.IsNullOrEmpty(response.Model)) +// { +// Logger.LogDebug("Response model is empty or null"); +// return; +// } + +// var (inputCost, outputCost) = Configuration.Prices.CalculateCost(response.Model, usage.PromptTokens, usage.CompletionTokens); + +// if (inputCost > 0) +// { +// var totalCost = inputCost + outputCost; +// _ = activity.SetTag(SemanticConvention.GEN_AI_USAGE_COST, totalCost); + +// Debug.Assert(_requestCostMetric is not null, "Cost histogram is not initialized"); +// Debug.Assert(_totalCostMetric is not null, "Total cost counter is not initialized"); + +// _requestCostMetric.Record(totalCost, +// [ +// new(SemanticConvention.GEN_AI_OPERATION_NAME, GetOperationName(request)), +// new(SemanticConvention.GEN_AI_SYSTEM, OpenAISystem), +// new(SemanticConvention.GEN_AI_REQUEST_MODEL, request.Model), +// new(SemanticConvention.GEN_AI_RESPONSE_MODEL, response.Model) +// ]); +// _totalCostMetric.Add(totalCost, +// [ +// new(SemanticConvention.GEN_AI_OPERATION_NAME, GetOperationName(request)), +// new(SemanticConvention.GEN_AI_SYSTEM, OpenAISystem), +// new(SemanticConvention.GEN_AI_REQUEST_MODEL, request.Model), +// new(SemanticConvention.GEN_AI_RESPONSE_MODEL, response.Model) +// ]); +// reportModelUsageInformation.Cost = totalCost; +// } +// else +// { +// Logger.LogDebug("Input cost is zero, skipping cost metrics recording"); +// } + +// Logger.LogTrace("RecordUsageMetrics() finished"); +// } + +// private bool TryGetOpenAIRequest(string content, out OpenAIRequest? request) +// { +// Logger.LogTrace("TryGetOpenAIRequest() called"); + +// request = null; + +// if (string.IsNullOrEmpty(content)) +// { +// Logger.LogDebug("Request content is empty or null"); +// return false; +// } + +// try +// { +// Logger.LogDebug("Checking if the request is an OpenAI request..."); + +// var rawRequest = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); + +// // Check for completion request (has "prompt", but not specific to image) +// if (rawRequest.TryGetProperty("prompt", out _) && +// !rawRequest.TryGetProperty("size", out _) && +// !rawRequest.TryGetProperty("n", out _)) +// { +// Logger.LogDebug("Request is a completion request"); +// request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); +// return true; +// } + +// // Chat completion request +// if (rawRequest.TryGetProperty("messages", out _)) +// { +// Logger.LogDebug("Request is a chat completion request"); +// request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); +// return true; +// } + +// // Embedding request +// if (rawRequest.TryGetProperty("input", out _) && +// rawRequest.TryGetProperty("model", out _) && +// !rawRequest.TryGetProperty("voice", out _)) +// { +// Logger.LogDebug("Request is an embedding request"); +// request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); +// return true; +// } + +// // Image generation request +// if (rawRequest.TryGetProperty("prompt", out _) && +// (rawRequest.TryGetProperty("size", out _) || rawRequest.TryGetProperty("n", out _))) +// { +// Logger.LogDebug("Request is an image generation request"); +// request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); +// return true; +// } + +// // Audio transcription request +// if (rawRequest.TryGetProperty("file", out _)) +// { +// Logger.LogDebug("Request is an audio transcription request"); +// request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); +// return true; +// } + +// // Audio speech synthesis request +// if (rawRequest.TryGetProperty("input", out _) && rawRequest.TryGetProperty("voice", out _)) +// { +// Logger.LogDebug("Request is an audio speech synthesis request"); +// request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); +// return true; +// } + +// // Fine-tuning request +// if (rawRequest.TryGetProperty("training_file", out _)) +// { +// Logger.LogDebug("Request is a fine-tuning request"); +// request = JsonSerializer.Deserialize(content, ProxyUtils.JsonSerializerOptions); +// return true; +// } + +// Logger.LogDebug("Request is not an OpenAI request."); +// return false; +// } +// catch (JsonException ex) +// { +// Logger.LogDebug(ex, "Failed to deserialize OpenAI request."); +// return false; +// } +// } + +// private static string GetOperationName(OpenAIRequest request) +// { +// if (request == null) +// { +// return "unknown"; +// } + +// return request switch +// { +// OpenAIChatCompletionRequest => "chat.completions", +// OpenAICompletionRequest => "completions", +// OpenAIEmbeddingRequest => "embeddings", +// OpenAIImageRequest => "images.generations", +// OpenAIAudioRequest => "audio.transcriptions", +// OpenAIAudioSpeechRequest => "audio.speech", +// OpenAIFineTuneRequest => "fine_tuning.jobs", +// _ => "unknown" +// }; +// } + +// private bool IsStreamingResponse(Response response) +// { +// Logger.LogTrace("{Method} called", nameof(IsStreamingResponse)); +// var contentType = response.Headers.FirstOrDefault(h => h.Name.Equals("content-type", StringComparison.OrdinalIgnoreCase))?.Value; +// if (string.IsNullOrEmpty(contentType)) +// { +// Logger.LogDebug("No content-type header found"); +// return false; +// } + +// var isStreamingResponse = contentType.Contains("text/event-stream", StringComparison.OrdinalIgnoreCase); +// Logger.LogDebug("IsStreamingResponse: {IsStreamingResponse}", isStreamingResponse); + +// Logger.LogTrace("{Method} finished", nameof(IsStreamingResponse)); +// return isStreamingResponse; +// } + +// private string GetBodyFromStreamingResponse(Response response) +// { +// Logger.LogTrace("{Method} called", nameof(GetBodyFromStreamingResponse)); + +// // default to the whole body +// var bodyString = response.BodyString; + +// var chunks = bodyString.Split("\n\n", StringSplitOptions.RemoveEmptyEntries); +// if (chunks.Length == 0) +// { +// Logger.LogDebug("No chunks found in the response body"); +// return bodyString; +// } + +// // check if the last chunk is `data: [DONE]` +// var lastChunk = chunks.Last().Trim(); +// if (lastChunk.Equals("data: [DONE]", StringComparison.OrdinalIgnoreCase)) +// { +// // get next to last chunk +// var chunk = chunks.Length > 1 ? chunks[^2].Trim() : string.Empty; +// if (chunk.StartsWith("data: ", StringComparison.OrdinalIgnoreCase)) +// { +// // remove the "data: " prefix +// bodyString = chunk["data: ".Length..].Trim(); +// Logger.LogDebug("Last chunk starts with 'data: ', using the last chunk as the body: {BodyString}", bodyString); +// } +// else +// { +// Logger.LogDebug("Last chunk does not start with 'data: ', using the whole body"); +// } +// } +// else +// { +// Logger.LogDebug("Last chunk is not `data: [DONE]`, using the whole body"); +// } + +// Logger.LogTrace("{Method} finished", nameof(GetBodyFromStreamingResponse)); +// return bodyString; +// } + +// public void Dispose() +// { +// _loader?.Dispose(); +// _activitySource?.Dispose(); +// _tracerProvider?.Dispose(); +// _meterProvider?.Dispose(); +// } +//} diff --git a/DevProxy.Plugins/Manipulation/RewritePlugin.cs b/DevProxy.Plugins/Manipulation/RewritePlugin.cs index 9e2d95ed..def7aba2 100644 --- a/DevProxy.Plugins/Manipulation/RewritePlugin.cs +++ b/DevProxy.Plugins/Manipulation/RewritePlugin.cs @@ -59,26 +59,26 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell await _loader.InitFileWatcherAsync(cancellationToken); } - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); - ArgumentNullException.ThrowIfNull(e); - - if (!e.HasRequestUrlMatch(UrlsToWatch)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request, args.RequestId); + return Task.FromResult(PluginResponse.Continue()); } if (Configuration.Rewrites is null || !Configuration.Rewrites.Any()) { - Logger.LogRequest("No rewrites configured", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("No rewrites configured", MessageType.Skipped, args.Request, args.RequestId); + return Task.FromResult(PluginResponse.Continue()); } - var request = e.Session.HttpClient.Request; + var originalUrl = args.Request.RequestUri?.ToString() ?? string.Empty; + var newUrl = originalUrl; + var wasRewritten = false; foreach (var rewrite in Configuration.Rewrites) { @@ -88,20 +88,41 @@ public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken ca continue; } - var newUrl = Regex.Replace(request.Url, rewrite.In.Url, rewrite.Out.Url, RegexOptions.IgnoreCase); + var rewrittenUrl = Regex.Replace(newUrl, rewrite.In.Url, rewrite.Out.Url, RegexOptions.IgnoreCase); - if (request.Url.Equals(newUrl, StringComparison.OrdinalIgnoreCase)) + if (newUrl.Equals(rewrittenUrl, StringComparison.OrdinalIgnoreCase)) { - Logger.LogRequest($"{rewrite.In?.Url}", MessageType.Skipped, new(e.Session)); + Logger.LogRequest($"{rewrite.In?.Url}", MessageType.Skipped, args.Request, args.RequestId); } else { - Logger.LogRequest($"{rewrite.In?.Url} > {newUrl}", MessageType.Processed, new(e.Session)); - request.Url = newUrl; + Logger.LogRequest($"{rewrite.In?.Url} > {rewrittenUrl}", MessageType.Processed, args.Request, args.RequestId); + newUrl = rewrittenUrl; + wasRewritten = true; } } - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); - return Task.CompletedTask; - } + if (wasRewritten && Uri.TryCreate(newUrl, UriKind.Absolute, out var newUri)) + { + var newRequest = new HttpRequestMessage(args.Request.Method, newUri); + + // Copy headers + foreach (var header in args.Request.Headers) + { + _ = newRequest.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + + // Copy content and content headers + if (args.Request.Content != null) + { + newRequest.Content = args.Request.Content; + } + + Logger.LogTrace("Left {Name}", nameof(OnRequestAsync)); + return Task.FromResult(PluginResponse.Continue(newRequest)); + } + + Logger.LogTrace("Left {Name}", nameof(OnRequestAsync)); + return Task.FromResult(PluginResponse.Continue()); + }; } \ No newline at end of file diff --git a/DevProxy.Plugins/Mocking/AuthPlugin.cs b/DevProxy.Plugins/Mocking/AuthPlugin.cs index d688cd1f..c5000aed 100644 --- a/DevProxy.Plugins/Mocking/AuthPlugin.cs +++ b/DevProxy.Plugins/Mocking/AuthPlugin.cs @@ -14,12 +14,10 @@ using System.IdentityModel.Tokens.Jwt; using System.Net; using System.Security.Claims; +using System.Text; using System.Text.Json; using System.Text.Json.Serialization; using System.Web; -using Titanium.Web.Proxy.EventArguments; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; namespace DevProxy.Plugins.Mocking; @@ -156,36 +154,28 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell Enabled = false; } - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); - ArgumentNullException.ThrowIfNull(e); - - if (!e.HasRequestUrlMatch(UrlsToWatch)) - { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; - } - if (e.ResponseState.HasBeenSet) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("Response already set", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); + return Task.FromResult(PluginResponse.Continue()); } - if (!AuthorizeRequest(e.Session)) + if (!AuthorizeRequest(args.Request)) { - SendUnauthorizedResponse(e.Session); - e.ResponseState.HasBeenSet = true; + return Task.FromResult(PluginResponse.Respond(BuildUnauthorizedResponse(args.Request))); } else { - Logger.LogRequest("Request authorized", MessageType.Normal, new(e.Session)); + Logger.LogRequest("Request authorized", MessageType.Normal, args.Request); } - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); - return Task.CompletedTask; - } + Logger.LogTrace("Left {Name}", nameof(OnRequestAsync)); + return Task.FromResult(PluginResponse.Continue()); + }; private async Task SetupOpenIdConnectConfigurationAsync(string metadataUrl) { @@ -201,20 +191,20 @@ private async Task SetupOpenIdConnectConfigurationAsync(string metadataUrl) } } - private bool AuthorizeRequest(SessionEventArgs session) + private bool AuthorizeRequest(HttpRequestMessage request) { Debug.Assert(Configuration is not null); Debug.Assert(Configuration.Type is not null); return Configuration.Type switch { - AuthPluginAuthType.ApiKey => AuthorizeApiKeyRequest(session), - AuthPluginAuthType.OAuth2 => AuthorizeOAuth2Request(session), + AuthPluginAuthType.ApiKey => AuthorizeApiKeyRequest(request), + AuthPluginAuthType.OAuth2 => AuthorizeOAuth2Request(request), _ => false, }; } - private bool AuthorizeApiKeyRequest(SessionEventArgs session) + private bool AuthorizeApiKeyRequest(HttpRequestMessage request) { Logger.LogDebug("Authorizing request using API key"); @@ -222,23 +212,23 @@ private bool AuthorizeApiKeyRequest(SessionEventArgs session) Debug.Assert(Configuration.ApiKey is not null); Debug.Assert(Configuration.ApiKey.AllowedKeys is not null); - var apiKey = GetApiKey(session); + var apiKey = GetApiKey(request); if (apiKey is null) { - Logger.LogRequest("401 Unauthorized. API key not found.", MessageType.Failed, new(session)); + Logger.LogRequest("401 Unauthorized. API key not found.", MessageType.Failed, request); return false; } var isKeyValid = Configuration.ApiKey.AllowedKeys.Contains(apiKey); if (!isKeyValid) { - Logger.LogRequest($"401 Unauthorized. API key {apiKey} is not allowed.", MessageType.Failed, new(session)); + Logger.LogRequest($"401 Unauthorized. API key {apiKey} is not allowed.", MessageType.Failed, request); } return isKeyValid; } - private bool AuthorizeOAuth2Request(SessionEventArgs session) + private bool AuthorizeOAuth2Request(HttpRequestMessage request) { Logger.LogDebug("Authorizing request using OAuth2"); @@ -247,7 +237,7 @@ private bool AuthorizeOAuth2Request(SessionEventArgs session) Debug.Assert(Configuration.OAuth2.MetadataUrl is not null); Debug.Assert(_openIdConnectConfiguration is not null); - var token = GetOAuth2Token(session); + var token = GetOAuth2Token(request); if (token is null) { return false; @@ -276,20 +266,20 @@ private bool AuthorizeOAuth2Request(SessionEventArgs session) try { var claimsPrincipal = handler.ValidateToken(token, validationParameters, out _); - return ValidateTenants(claimsPrincipal, session) && - ValidateApplications(claimsPrincipal, session) && - ValidatePrincipals(claimsPrincipal, session) && - ValidateRoles(claimsPrincipal, session) && - ValidateScopes(claimsPrincipal, session); + return ValidateTenants(claimsPrincipal, request) && + ValidateApplications(claimsPrincipal, request) && + ValidatePrincipals(claimsPrincipal, request) && + ValidateRoles(claimsPrincipal, request) && + ValidateScopes(claimsPrincipal, request); } catch (Exception ex) { - Logger.LogRequest($"401 Unauthorized. The specified token is not valid: {ex.Message}", MessageType.Failed, new(session)); + Logger.LogRequest($"401 Unauthorized. The specified token is not valid: {ex.Message}", MessageType.Failed, request); return false; } } - private bool ValidatePrincipals(ClaimsPrincipal claimsPrincipal, SessionEventArgs session) + private bool ValidatePrincipals(ClaimsPrincipal claimsPrincipal, HttpRequestMessage request) { Debug.Assert(Configuration is not null); Debug.Assert(Configuration.OAuth2 is not null); @@ -303,14 +293,14 @@ private bool ValidatePrincipals(ClaimsPrincipal claimsPrincipal, SessionEventArg var principalId = claimsPrincipal.FindFirst("http://schemas.microsoft.com/identity/claims/objectidentifier")?.Value; if (principalId is null) { - Logger.LogRequest("401 Unauthorized. The specified token doesn't have the oid claim.", MessageType.Failed, new(session)); + Logger.LogRequest("401 Unauthorized. The specified token doesn't have the oid claim.", MessageType.Failed, request); return false; } if (!Configuration.OAuth2.AllowedPrincipals.Contains(principalId)) { var principals = string.Join(", ", Configuration.OAuth2.AllowedPrincipals); - Logger.LogRequest($"401 Unauthorized. The specified token is not issued for an allowed principal. Allowed principals: {principals}, found: {principalId}", MessageType.Failed, new(session)); + Logger.LogRequest($"401 Unauthorized. The specified token is not issued for an allowed principal. Allowed principals: {principals}, found: {principalId}", MessageType.Failed, request); return false; } @@ -319,7 +309,7 @@ private bool ValidatePrincipals(ClaimsPrincipal claimsPrincipal, SessionEventArg return true; } - private bool ValidateApplications(ClaimsPrincipal claimsPrincipal, SessionEventArgs session) + private bool ValidateApplications(ClaimsPrincipal claimsPrincipal, HttpRequestMessage request) { Debug.Assert(Configuration is not null); Debug.Assert(Configuration.OAuth2 is not null); @@ -333,21 +323,21 @@ private bool ValidateApplications(ClaimsPrincipal claimsPrincipal, SessionEventA var tokenVersion = claimsPrincipal.FindFirst("ver")?.Value; if (tokenVersion is null) { - Logger.LogRequest("401 Unauthorized. The specified token doesn't have the ver claim.", MessageType.Failed, new(session)); + Logger.LogRequest("401 Unauthorized. The specified token doesn't have the ver claim.", MessageType.Failed, request); return false; } var appId = claimsPrincipal.FindFirst(tokenVersion == "1.0" ? "appid" : "azp")?.Value; if (appId is null) { - Logger.LogRequest($"401 Unauthorized. The specified token doesn't have the {(tokenVersion == "v1.0" ? "appid" : "azp")} claim.", MessageType.Failed, new(session)); + Logger.LogRequest($"401 Unauthorized. The specified token doesn't have the {(tokenVersion == "v1.0" ? "appid" : "azp")} claim.", MessageType.Failed, request); return false; } if (!Configuration.OAuth2.AllowedApplications.Contains(appId)) { var applications = string.Join(", ", Configuration.OAuth2.AllowedApplications); - Logger.LogRequest($"401 Unauthorized. The specified token is not issued by an allowed application. Allowed applications: {applications}, found: {appId}", MessageType.Failed, new(session)); + Logger.LogRequest($"401 Unauthorized. The specified token is not issued by an allowed application. Allowed applications: {applications}, found: {appId}", MessageType.Failed, request); return false; } @@ -356,7 +346,7 @@ private bool ValidateApplications(ClaimsPrincipal claimsPrincipal, SessionEventA return true; } - private bool ValidateTenants(ClaimsPrincipal claimsPrincipal, SessionEventArgs session) + private bool ValidateTenants(ClaimsPrincipal claimsPrincipal, HttpRequestMessage request) { Debug.Assert(Configuration is not null); Debug.Assert(Configuration.OAuth2 is not null); @@ -370,14 +360,14 @@ private bool ValidateTenants(ClaimsPrincipal claimsPrincipal, SessionEventArgs s var tenantId = claimsPrincipal.FindFirst("http://schemas.microsoft.com/identity/claims/tenantid")?.Value; if (tenantId is null) { - Logger.LogRequest("401 Unauthorized. The specified token doesn't have the tid claim.", MessageType.Failed, new(session)); + Logger.LogRequest("401 Unauthorized. The specified token doesn't have the tid claim.", MessageType.Failed, request); return false; } if (!Configuration.OAuth2.AllowedTenants.Contains(tenantId)) { var tenants = string.Join(", ", Configuration.OAuth2.AllowedTenants); - Logger.LogRequest($"401 Unauthorized. The specified token is not issued by an allowed tenant. Allowed tenants: {tenants}, found: {tenantId}", MessageType.Failed, new(session)); + Logger.LogRequest($"401 Unauthorized. The specified token is not issued by an allowed tenant. Allowed tenants: {tenants}, found: {tenantId}", MessageType.Failed, request); return false; } @@ -386,7 +376,7 @@ private bool ValidateTenants(ClaimsPrincipal claimsPrincipal, SessionEventArgs s return true; } - private bool ValidateRoles(ClaimsPrincipal claimsPrincipal, SessionEventArgs session) + private bool ValidateRoles(ClaimsPrincipal claimsPrincipal, HttpRequestMessage request) { Debug.Assert(Configuration is not null); Debug.Assert(Configuration.OAuth2 is not null); @@ -404,7 +394,7 @@ private bool ValidateRoles(ClaimsPrincipal claimsPrincipal, SessionEventArgs ses var rolesRequired = string.Join(", ", Configuration.OAuth2.Roles); if (!Configuration.OAuth2.Roles.Any(r => HasPermission(r, rolesFromTheToken))) { - Logger.LogRequest($"401 Unauthorized. The specified token does not have the necessary role(s). Required one of: {rolesRequired}, found: {rolesFromTheToken}", MessageType.Failed, new(session)); + Logger.LogRequest($"401 Unauthorized. The specified token does not have the necessary role(s). Required one of: {rolesRequired}, found: {rolesFromTheToken}", MessageType.Failed, request); return false; } @@ -413,7 +403,7 @@ private bool ValidateRoles(ClaimsPrincipal claimsPrincipal, SessionEventArgs ses return true; } - private bool ValidateScopes(ClaimsPrincipal claimsPrincipal, SessionEventArgs session) + private bool ValidateScopes(ClaimsPrincipal claimsPrincipal, HttpRequestMessage request) { Debug.Assert(Configuration is not null); Debug.Assert(Configuration.OAuth2 is not null); @@ -431,7 +421,7 @@ private bool ValidateScopes(ClaimsPrincipal claimsPrincipal, SessionEventArgs se var scopesRequired = string.Join(", ", Configuration.OAuth2.Scopes); if (!Configuration.OAuth2.Scopes.Any(s => HasPermission(s, scopesFromTheToken))) { - Logger.LogRequest($"401 Unauthorized. The specified token does not have the necessary scope(s). Required one of: {scopesRequired}, found: {scopesFromTheToken}", MessageType.Failed, new LoggingContext(session)); + Logger.LogRequest($"401 Unauthorized. The specified token does not have the necessary scope(s). Required one of: {scopesRequired}, found: {scopesFromTheToken}", MessageType.Failed, request); return false; } @@ -440,29 +430,26 @@ private bool ValidateScopes(ClaimsPrincipal claimsPrincipal, SessionEventArgs se return true; } - private string? GetOAuth2Token(SessionEventArgs session) + private string? GetOAuth2Token(HttpRequestMessage request) { - var tokenParts = session.HttpClient.Request.Headers - .FirstOrDefault(h => h.Name.Equals("Authorization", StringComparison.OrdinalIgnoreCase)) - ?.Value - ?.Split(' '); - - if (tokenParts is null) + var authHeaderValue = request.Headers.Authorization?.ToString(); + if (authHeaderValue is null) { - Logger.LogRequest("401 Unauthorized. Authorization header not found.", MessageType.Failed, new(session)); + Logger.LogRequest("401 Unauthorized. Authorization header not found.", MessageType.Failed, request); return null; } + var tokenParts = authHeaderValue.Split(' '); if (tokenParts.Length != 2 || tokenParts[0] != "Bearer") { - Logger.LogRequest("401 Unauthorized. The specified token is not a valid Bearer token.", MessageType.Failed, new(session)); + Logger.LogRequest("401 Unauthorized. The specified token is not a valid Bearer token.", MessageType.Failed, request); return null; } return tokenParts[1]; } - private string? GetApiKey(SessionEventArgs session) + private string? GetApiKey(HttpRequestMessage request) { Debug.Assert(Configuration is not null); Debug.Assert(Configuration.ApiKey is not null); @@ -480,9 +467,9 @@ private bool ValidateScopes(ClaimsPrincipal claimsPrincipal, SessionEventArgs se Logger.LogDebug("Getting API key from parameter {Param} in {In}", parameter.Name, parameter.In); apiKey = parameter.In switch { - AuthPluginApiKeyIn.Header => GetApiKeyFromHeader(session.HttpClient.Request, parameter.Name), - AuthPluginApiKeyIn.Query => GetApiKeyFromQuery(session.HttpClient.Request, parameter.Name), - AuthPluginApiKeyIn.Cookie => GetApiKeyFromCookie(session.HttpClient.Request, parameter.Name), + AuthPluginApiKeyIn.Header => GetApiKeyFromHeader(request, parameter.Name), + AuthPluginApiKeyIn.Query => GetApiKeyFromQuery(request, parameter.Name), + AuthPluginApiKeyIn.Cookie => GetApiKeyFromCookie(request, parameter.Name), _ => null }; Logger.LogDebug("API key from parameter {Param} in {In}: {ApiKey}", parameter.Name, parameter.In, apiKey ?? "(not found)"); @@ -496,7 +483,7 @@ private bool ValidateScopes(ClaimsPrincipal claimsPrincipal, SessionEventArgs se return apiKey; } - private static void SendUnauthorizedResponse(SessionEventArgs e) + private static HttpResponseMessage BuildUnauthorizedResponse(HttpRequestMessage request) { var body = new { @@ -505,19 +492,18 @@ private static void SendUnauthorizedResponse(SessionEventArgs e) message = "Unauthorized" } }; - SendJsonResponse(JsonSerializer.Serialize(body, ProxyUtils.JsonSerializerOptions), HttpStatusCode.Unauthorized, e); - } - private static void SendJsonResponse(string body, HttpStatusCode statusCode, SessionEventArgs e) - { - var headers = new List { - new("content-type", "application/json; charset=utf-8") + var response = new HttpResponseMessage(HttpStatusCode.Unauthorized) + { + Content = new StringContent(JsonSerializer.Serialize(body, ProxyUtils.JsonSerializerOptions), Encoding.UTF8, "application/json") }; - if (e.HttpClient.Request.Headers.Any(h => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase))) + + if (request.Headers.TryGetValues("Origin", out var _)) { - headers.Add(new("access-control-allow-origin", "*")); + _ = response.Headers.TryAddWithoutValidation("access-control-allow-origin", "*"); } - e.GenericResponse(body, statusCode, headers); + + return response; } private static bool HasPermission(string permission, string permissionString) @@ -531,9 +517,15 @@ private static bool HasPermission(string permission, string permissionString) return permissions.Contains(permission, StringComparer.OrdinalIgnoreCase); } - private static string? GetApiKeyFromCookie(Request request, string cookieName) + private static string? GetApiKeyFromCookie(HttpRequestMessage request, string cookieName) { - var cookies = ParseCookies(request.Headers.FirstOrDefault(h => h.Name.Equals("Cookie", StringComparison.OrdinalIgnoreCase))?.Value); + if (!request.Headers.TryGetValues("Cookie", out var cookieValues)) + { + return null; + } + + var cookieHeader = cookieValues.FirstOrDefault(); + var cookies = ParseCookies(cookieHeader); if (cookies is null) { return null; @@ -562,14 +554,18 @@ private static bool HasPermission(string permission, string permissionString) return cookies; } - private static string? GetApiKeyFromQuery(Request request, string paramName) + private static string? GetApiKeyFromQuery(HttpRequestMessage request, string paramName) { - var queryParameters = HttpUtility.ParseQueryString(request.RequestUri.Query); + var queryParameters = HttpUtility.ParseQueryString(request.RequestUri?.Query ?? string.Empty); return queryParameters[paramName]; } - private static string? GetApiKeyFromHeader(Request request, string headerName) + private static string? GetApiKeyFromHeader(HttpRequestMessage request, string headerName) { - return request.Headers.FirstOrDefault(h => h.Name == headerName)?.Value; + if (request.Headers.TryGetValues(headerName, out var values)) + { + return values.FirstOrDefault(); + } + return null; } } diff --git a/DevProxy.Plugins/Mocking/CrudApiPlugin.cs b/DevProxy.Plugins/Mocking/CrudApiPlugin.cs index 979b14e6..fe4611c3 100644 --- a/DevProxy.Plugins/Mocking/CrudApiPlugin.cs +++ b/DevProxy.Plugins/Mocking/CrudApiPlugin.cs @@ -17,11 +17,9 @@ using System.Diagnostics; using System.Net; using System.Security.Claims; +using System.Text; using System.Text.Json.Serialization; using System.Text.RegularExpressions; -using Titanium.Web.Proxy.EventArguments; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; namespace DevProxy.Plugins.Mocking; @@ -130,61 +128,47 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell await SetupOpenIdConnectConfigurationAsync(); } - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => async (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); - ArgumentNullException.ThrowIfNull(e); - - var request = e.Session.HttpClient.Request; - var state = e.ResponseState; - - if (!e.HasRequestUrlMatch(UrlsToWatch)) - { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; - } - if (e.ResponseState.HasBeenSet) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("Response already set", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); + return PluginResponse.Continue(); } - if (IsCORSPreflightRequest(request) && Configuration.EnableCORS) + if (IsCORSPreflightRequest(args.Request) && Configuration.EnableCORS) { - SendEmptyResponse(HttpStatusCode.NoContent, e.Session); - Logger.LogRequest("CORS preflight request", MessageType.Mocked, new LoggingContext(e.Session)); - return Task.CompletedTask; + var corsResponse = BuildEmptyResponse(HttpStatusCode.NoContent, args.Request); + Logger.LogRequest("CORS preflight request", MessageType.Mocked, args.Request); + return PluginResponse.Respond(corsResponse); } - if (!AuthorizeRequest(e)) + if (!AuthorizeRequest(args.Request)) { - SendUnauthorizedResponse(e.Session); - state.HasBeenSet = true; - return Task.CompletedTask; + return PluginResponse.Respond(BuildUnauthorizedResponse(args.Request)); } - var actionAndParams = GetMatchingActionHandler(request); + var actionAndParams = GetMatchingActionHandler(args.Request); if (actionAndParams is not null) { - if (!AuthorizeRequest(e, actionAndParams.Value.action)) + if (!AuthorizeRequest(args.Request, actionAndParams.Value.action)) { - SendUnauthorizedResponse(e.Session); - state.HasBeenSet = true; - return Task.CompletedTask; + return PluginResponse.Respond(BuildUnauthorizedResponse(args.Request)); } - actionAndParams.Value.handler(e.Session, actionAndParams.Value.action, actionAndParams.Value.parameters); - state.HasBeenSet = true; + var response = await actionAndParams.Value.handler(args.Request, actionAndParams.Value.action, actionAndParams.Value.parameters, cancellationToken); + return PluginResponse.Respond(response); } else { - Logger.LogRequest("Did not match any action", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("Did not match any action", MessageType.Skipped, args.Request); } - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); - return Task.CompletedTask; - } + Logger.LogTrace("Left {Name}", nameof(OnRequestAsync)); + return PluginResponse.Continue(); + }; private async Task SetupOpenIdConnectConfigurationAsync() { @@ -221,7 +205,7 @@ private void LoadData() } } - private (Action> handler, CrudApiAction action, IDictionary parameters)? GetMatchingActionHandler(Request request) + private (Func, CancellationToken, Task> handler, CrudApiAction action, IDictionary parameters)? GetMatchingActionHandler(HttpRequestMessage request) { if (Configuration.Actions is null || !Configuration.Actions.Any()) @@ -238,14 +222,14 @@ private void LoadData() var parameters = new Dictionary(); var action = Configuration.Actions.FirstOrDefault(action => { - if (action.Method != request.Method) + if (action.Method != request.Method.Method) { return false; } var absoluteActionUrl = (Configuration.BaseUrl + action.Url).Replace("//", "/", 8); - if (absoluteActionUrl == request.Url) + if (absoluteActionUrl == request.RequestUri?.ToString()) { return true; } @@ -259,7 +243,7 @@ private void LoadData() // convert parameters into named regex groups var urlRegex = Regex.Replace(Regex.Escape(absoluteActionUrl).Replace("\\{", "{", StringComparison.OrdinalIgnoreCase), "({[^}]+})", parameterMatchEvaluator); - var match = Regex.Match(request.Url, urlRegex); + var match = Regex.Match(request.RequestUri?.ToString() ?? string.Empty, urlRegex); if (!match.Success) { return false; @@ -283,31 +267,36 @@ private void LoadData() return (handler: action.Action switch { - CrudApiActionType.Create => Create, - CrudApiActionType.GetAll => GetAll, - CrudApiActionType.GetOne => GetOne, - CrudApiActionType.GetMany => GetMany, - CrudApiActionType.Merge => Merge, - CrudApiActionType.Update => Update, - CrudApiActionType.Delete => Delete, + CrudApiActionType.Create => CreateAsync, + CrudApiActionType.GetAll => GetAllAsync, + CrudApiActionType.GetOne => GetOneAsync, + CrudApiActionType.GetMany => GetManyAsync, + CrudApiActionType.Merge => MergeAsync, + CrudApiActionType.Update => UpdateAsync, + CrudApiActionType.Delete => DeleteAsync, _ => throw new NotImplementedException() }, action, parameters); } - private void AddCORSHeaders(Request request, List headers) + private void AddCORSHeaders(HttpRequestMessage request, HttpResponseMessage response) { - var origin = request.Headers.FirstOrDefault(h => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase))?.Value; + if (!request.Headers.TryGetValues("Origin", out var originValues)) + { + return; + } + + var origin = originValues.FirstOrDefault(); if (string.IsNullOrEmpty(origin)) { return; } - headers.Add(new HttpHeader("access-control-allow-origin", origin)); + _ = response.Headers.TryAddWithoutValidation("access-control-allow-origin", origin); if (Configuration.EntraAuthConfig is not null || Configuration.Actions.Any(a => a.Auth == CrudApiAuthType.Entra)) { - headers.Add(new HttpHeader("access-control-allow-headers", "authorization, content-type")); + _ = response.Headers.TryAddWithoutValidation("access-control-allow-headers", "authorization, content-type"); } var methods = string.Join(", ", Configuration.Actions @@ -315,10 +304,10 @@ private void AddCORSHeaders(Request request, List headers) .Select(a => a.Method) .Distinct()); - headers.Add(new HttpHeader("access-control-allow-methods", methods)); + _ = response.Headers.TryAddWithoutValidation("access-control-allow-methods", methods); } - private bool AuthorizeRequest(ProxyRequestArgs e, CrudApiAction? action = null) + private bool AuthorizeRequest(HttpRequestMessage request, CrudApiAction? action = null) { var authType = action is null ? Configuration.Auth : action.Auth; var authConfig = action is null ? Configuration.EntraAuthConfig : action.EntraAuthConfig; @@ -334,19 +323,19 @@ private bool AuthorizeRequest(ProxyRequestArgs e, CrudApiAction? action = null) Debug.Assert(authConfig is not null, "EntraAuthConfig is null when auth is required."); - var token = e.Session.HttpClient.Request.Headers.FirstOrDefault(h => h.Name.Equals("Authorization", StringComparison.OrdinalIgnoreCase))?.Value; + var authHeaderValue = request.Headers.Authorization?.ToString(); // is there a token - if (string.IsNullOrEmpty(token)) + if (string.IsNullOrEmpty(authHeaderValue)) { - Logger.LogRequest("401 Unauthorized. No token found on the request.", MessageType.Failed, new(e.Session)); + Logger.LogRequest("401 Unauthorized. No token found on the request.", MessageType.Failed, request); return false; } // does the token has a valid format - var tokenHeaderParts = token.Split(' '); + var tokenHeaderParts = authHeaderValue.Split(' '); if (tokenHeaderParts.Length != 2 || tokenHeaderParts[0] != "Bearer") { - Logger.LogRequest("401 Unauthorized. The specified token is not a valid Bearer token.", MessageType.Failed, new(e.Session)); + Logger.LogRequest("401 Unauthorized. The specified token is not a valid Bearer token.", MessageType.Failed, request); return false; } @@ -386,7 +375,7 @@ private bool AuthorizeRequest(ProxyRequestArgs e, CrudApiAction? action = null) { var rolesRequired = string.Join(", ", authConfig.Roles); - Logger.LogRequest($"401 Unauthorized. The specified token does not have the necessary role(s). Required one of: {rolesRequired}, found: {rolesFromTheToken}", MessageType.Failed, new(e.Session)); + Logger.LogRequest($"401 Unauthorized. The specified token does not have the necessary role(s). Required one of: {rolesRequired}, found: {rolesFromTheToken}", MessageType.Failed, request); return false; } @@ -402,7 +391,7 @@ private bool AuthorizeRequest(ProxyRequestArgs e, CrudApiAction? action = null) { var scopesRequired = string.Join(", ", authConfig.Scopes); - Logger.LogRequest($"401 Unauthorized. The specified token does not have the necessary scope(s). Required one of: {scopesRequired}, found: {scopesFromTheToken}", MessageType.Failed, new(e.Session)); + Logger.LogRequest($"401 Unauthorized. The specified token does not have the necessary scope(s). Required one of: {scopesRequired}, found: {scopesFromTheToken}", MessageType.Failed, request); return false; } @@ -411,14 +400,14 @@ private bool AuthorizeRequest(ProxyRequestArgs e, CrudApiAction? action = null) } catch (Exception ex) { - Logger.LogRequest($"401 Unauthorized. The specified token is not valid: {ex.Message}", MessageType.Failed, new(e.Session)); + Logger.LogRequest($"401 Unauthorized. The specified token is not valid: {ex.Message}", MessageType.Failed, request); return false; } return true; } - private void SendUnauthorizedResponse(SessionEventArgs e) + private HttpResponseMessage BuildUnauthorizedResponse(HttpRequestMessage request) { var body = new { @@ -427,10 +416,10 @@ private void SendUnauthorizedResponse(SessionEventArgs e) message = "Unauthorized" } }; - SendJsonResponse(System.Text.Json.JsonSerializer.Serialize(body, ProxyUtils.JsonSerializerOptions), HttpStatusCode.Unauthorized, e); + return BuildJsonResponse(System.Text.Json.JsonSerializer.Serialize(body, ProxyUtils.JsonSerializerOptions), HttpStatusCode.Unauthorized, request); } - private void SendNotFoundResponse(SessionEventArgs e) + private HttpResponseMessage BuildNotFoundResponse(HttpRequestMessage request) { var body = new { @@ -439,157 +428,177 @@ private void SendNotFoundResponse(SessionEventArgs e) message = "Not found" } }; - SendJsonResponse(System.Text.Json.JsonSerializer.Serialize(body, ProxyUtils.JsonSerializerOptions), HttpStatusCode.NotFound, e); + return BuildJsonResponse(System.Text.Json.JsonSerializer.Serialize(body, ProxyUtils.JsonSerializerOptions), HttpStatusCode.NotFound, request); } - private void SendEmptyResponse(HttpStatusCode statusCode, SessionEventArgs e) + private HttpResponseMessage BuildEmptyResponse(HttpStatusCode statusCode, HttpRequestMessage request) { - var headers = new List(); - AddCORSHeaders(e.HttpClient.Request, headers); - e.GenericResponse("", statusCode, headers); + var response = new HttpResponseMessage(statusCode) + { + Content = new StringContent(string.Empty, Encoding.UTF8, "text/plain") + }; + AddCORSHeaders(request, response); + return response; } - private void SendJsonResponse(string body, HttpStatusCode statusCode, SessionEventArgs e) + private HttpResponseMessage BuildJsonResponse(string body, HttpStatusCode statusCode, HttpRequestMessage request) { - var headers = new List { - new("content-type", "application/json; charset=utf-8") + var response = new HttpResponseMessage(statusCode) + { + Content = new StringContent(body, Encoding.UTF8, "application/json") }; - AddCORSHeaders(e.HttpClient.Request, headers); - e.GenericResponse(body, statusCode, headers); + AddCORSHeaders(request, response); + return response; } - private void GetAll(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + private Task GetAllAsync(HttpRequestMessage request, CrudApiAction action, IDictionary parameters, CancellationToken cancellationToken) { - SendJsonResponse(JsonConvert.SerializeObject(_data, Formatting.Indented), HttpStatusCode.OK, e); - Logger.LogRequest($"200 {action.Url}", MessageType.Mocked, new(e)); + var response = BuildJsonResponse(JsonConvert.SerializeObject(_data, Formatting.Indented), HttpStatusCode.OK, request); + Logger.LogRequest($"200 {action.Url}", MessageType.Mocked, request); + return Task.FromResult(response); } - private void GetOne(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + private Task GetOneAsync(HttpRequestMessage request, CrudApiAction action, IDictionary parameters, CancellationToken cancellationToken) { try { var item = _data?.SelectToken(ReplaceParams(action.Query, parameters)); if (item is null) { - SendNotFoundResponse(e); - Logger.LogRequest($"404 {action.Url}", MessageType.Mocked, new(e)); - return; + var response = BuildNotFoundResponse(request); + Logger.LogRequest($"404 {action.Url}", MessageType.Mocked, request); + return Task.FromResult(response); } - SendJsonResponse(JsonConvert.SerializeObject(item, Formatting.Indented), HttpStatusCode.OK, e); - Logger.LogRequest($"200 {action.Url}", MessageType.Mocked, new(e)); + var successResponse = BuildJsonResponse(JsonConvert.SerializeObject(item, Formatting.Indented), HttpStatusCode.OK, request); + Logger.LogRequest($"200 {action.Url}", MessageType.Mocked, request); + return Task.FromResult(successResponse); } catch (Exception ex) { - SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); - Logger.LogRequest($"500 {action.Url}", MessageType.Failed, new(e)); + var response = BuildJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, request); + Logger.LogRequest($"500 {action.Url}", MessageType.Failed, request); + return Task.FromResult(response); } } - private void GetMany(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + private Task GetManyAsync(HttpRequestMessage request, CrudApiAction action, IDictionary parameters, CancellationToken cancellationToken) { try { var items = (_data?.SelectTokens(ReplaceParams(action.Query, parameters))) ?? []; - SendJsonResponse(JsonConvert.SerializeObject(items, Formatting.Indented), HttpStatusCode.OK, e); - Logger.LogRequest($"200 {action.Url}", MessageType.Mocked, new(e)); + var response = BuildJsonResponse(JsonConvert.SerializeObject(items, Formatting.Indented), HttpStatusCode.OK, request); + Logger.LogRequest($"200 {action.Url}", MessageType.Mocked, request); + return Task.FromResult(response); } catch (Exception ex) { - SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); - Logger.LogRequest($"500 {action.Url}", MessageType.Failed, new(e)); + var response = BuildJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, request); + Logger.LogRequest($"500 {action.Url}", MessageType.Failed, request); + return Task.FromResult(response); } } - private void Create(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + private async Task CreateAsync(HttpRequestMessage request, CrudApiAction action, IDictionary parameters, CancellationToken cancellationToken) { try { - var data = JObject.Parse(e.HttpClient.Request.BodyString); + var bodyString = request.Content is not null ? await request.Content.ReadAsStringAsync(cancellationToken) : string.Empty; + var data = JObject.Parse(bodyString); _data?.Add(data); - SendJsonResponse(JsonConvert.SerializeObject(data, Formatting.Indented), HttpStatusCode.Created, e); - Logger.LogRequest($"201 {action.Url}", MessageType.Mocked, new(e)); + var response = BuildJsonResponse(JsonConvert.SerializeObject(data, Formatting.Indented), HttpStatusCode.Created, request); + Logger.LogRequest($"201 {action.Url}", MessageType.Mocked, request); + return response; } catch (Exception ex) { - SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); - Logger.LogRequest($"500 {action.Url}", MessageType.Failed, new(e)); + var response = BuildJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, request); + Logger.LogRequest($"500 {action.Url}", MessageType.Failed, request); + return response; } } - private void Merge(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + private async Task MergeAsync(HttpRequestMessage request, CrudApiAction action, IDictionary parameters, CancellationToken cancellationToken) { try { var item = _data?.SelectToken(ReplaceParams(action.Query, parameters)); if (item is null) { - SendNotFoundResponse(e); - Logger.LogRequest($"404 {action.Url}", MessageType.Mocked, new(e)); - return; + var response = BuildNotFoundResponse(request); + Logger.LogRequest($"404 {action.Url}", MessageType.Mocked, request); + return response; } - var update = JObject.Parse(e.HttpClient.Request.BodyString); + var bodyString = request.Content is not null ? await request.Content.ReadAsStringAsync(cancellationToken) : string.Empty; + var update = JObject.Parse(bodyString); ((JContainer)item).Merge(update); - SendEmptyResponse(HttpStatusCode.NoContent, e); - Logger.LogRequest($"204 {action.Url}", MessageType.Mocked, new(e)); + var successResponse = BuildEmptyResponse(HttpStatusCode.NoContent, request); + Logger.LogRequest($"204 {action.Url}", MessageType.Mocked, request); + return successResponse; } catch (Exception ex) { - SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); - Logger.LogRequest($"500 {action.Url}", MessageType.Failed, new(e)); + var response = BuildJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, request); + Logger.LogRequest($"500 {action.Url}", MessageType.Failed, request); + return response; } } - private void Update(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + private async Task UpdateAsync(HttpRequestMessage request, CrudApiAction action, IDictionary parameters, CancellationToken cancellationToken) { try { var item = _data?.SelectToken(ReplaceParams(action.Query, parameters)); if (item is null) { - SendNotFoundResponse(e); - Logger.LogRequest($"404 {action.Url}", MessageType.Mocked, new(e)); - return; + var response = BuildNotFoundResponse(request); + Logger.LogRequest($"404 {action.Url}", MessageType.Mocked, request); + return response; } - var update = JObject.Parse(e.HttpClient.Request.BodyString); + var bodyString = request.Content is not null ? await request.Content.ReadAsStringAsync(cancellationToken) : string.Empty; + var update = JObject.Parse(bodyString); ((JContainer)item).Replace(update); - SendEmptyResponse(HttpStatusCode.NoContent, e); - Logger.LogRequest($"204 {action.Url}", MessageType.Mocked, new(e)); + var successResponse = BuildEmptyResponse(HttpStatusCode.NoContent, request); + Logger.LogRequest($"204 {action.Url}", MessageType.Mocked, request); + return successResponse; } catch (Exception ex) { - SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); - Logger.LogRequest($"500 {action.Url}", MessageType.Failed, new(e)); + var response = BuildJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, request); + Logger.LogRequest($"500 {action.Url}", MessageType.Failed, request); + return response; } } - private void Delete(SessionEventArgs e, CrudApiAction action, IDictionary parameters) + private Task DeleteAsync(HttpRequestMessage request, CrudApiAction action, IDictionary parameters, CancellationToken cancellationToken) { try { var item = _data?.SelectToken(ReplaceParams(action.Query, parameters)); if (item is null) { - SendNotFoundResponse(e); - Logger.LogRequest($"404 {action.Url}", MessageType.Mocked, new(e)); - return; + var response = BuildNotFoundResponse(request); + Logger.LogRequest($"404 {action.Url}", MessageType.Mocked, request); + return Task.FromResult(response); } item.Remove(); - SendEmptyResponse(HttpStatusCode.NoContent, e); - Logger.LogRequest($"204 {action.Url}", MessageType.Mocked, new(e)); + var successResponse = BuildEmptyResponse(HttpStatusCode.NoContent, request); + Logger.LogRequest($"204 {action.Url}", MessageType.Mocked, request); + return Task.FromResult(successResponse); } catch (Exception ex) { - SendJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, e); - Logger.LogRequest($"500 {action.Url}", MessageType.Failed, new(e)); + var response = BuildJsonResponse(JsonConvert.SerializeObject(ex, Formatting.Indented), HttpStatusCode.InternalServerError, request); + Logger.LogRequest($"500 {action.Url}", MessageType.Failed, request); + return Task.FromResult(response); } } - private static bool IsCORSPreflightRequest(Request request) + private static bool IsCORSPreflightRequest(HttpRequestMessage request) { - return request.Method == "OPTIONS" && - request.Headers.Any(h => h.Name.Equals("Origin", StringComparison.OrdinalIgnoreCase)); + return request.Method == HttpMethod.Options && + request.Headers.TryGetValues("Origin", out var _); } private static bool HasPermission(string permission, string permissionString) diff --git a/DevProxy.Plugins/Mocking/EntraMockResponsePlugin.cs b/DevProxy.Plugins/Mocking/EntraMockResponsePlugin.cs index be1667f3..e0fbf91f 100644 --- a/DevProxy.Plugins/Mocking/EntraMockResponsePlugin.cs +++ b/DevProxy.Plugins/Mocking/EntraMockResponsePlugin.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using DevProxy.Abstractions.Models; +using DevProxy.Abstractions.Plugins; using DevProxy.Abstractions.Proxy; using DevProxy.Abstractions.Utils; using Microsoft.Extensions.Configuration; @@ -41,30 +42,32 @@ public sealed class EntraMockResponsePlugin( ISet urlsToWatch, X509Certificate2 certificate, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : MockResponsePlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, + proxyStorage) { private string? lastNonce; public override string Name => nameof(EntraMockResponsePlugin); // Running on POST requests with a body - protected override void ProcessMockResponse(ref byte[] body, IList headers, ProxyRequestArgs e, MockResponse? matchingResponse) + protected override void ProcessMockResponse(ref byte[] body, IList headers, HttpRequestMessage request, MockResponse? matchingResponse) { - ArgumentNullException.ThrowIfNull(e); + ArgumentNullException.ThrowIfNull(request); - base.ProcessMockResponse(ref body, headers, e, matchingResponse); + base.ProcessMockResponse(ref body, headers, request, matchingResponse); var bodyString = Encoding.UTF8.GetString(body); var changed = false; - StoreLastNonce(e); - UpdateMsalStateInBody(ref bodyString, e, ref changed); + StoreLastNonce(request); + UpdateMsalStateInBody(ref bodyString, request, ref changed); UpdateIdToken(ref bodyString, ref changed); UpdateDevProxyKeyId(ref bodyString, ref changed); UpdateDevProxyCertificateChain(ref bodyString, ref changed); @@ -76,12 +79,12 @@ protected override void ProcessMockResponse(ref byte[] body, IList headers, ProxyRequestArgs e, MockResponse? matchingResponse) + protected override void ProcessMockResponse(ref string? body, IList headers, HttpRequestMessage request, MockResponse? matchingResponse) { - base.ProcessMockResponse(ref body, headers, e, matchingResponse); + base.ProcessMockResponse(ref body, headers, request, matchingResponse); - StoreLastNonce(e); - UpdateMsalStateInHeaders(headers, e); + StoreLastNonce(request); + UpdateMsalStateInHeaders(headers, request); } private void UpdateDevProxyCertificateChain(ref string bodyString, ref bool changed) @@ -107,11 +110,11 @@ private void UpdateDevProxyKeyId(ref string bodyString, ref bool changed) changed = true; } - private void StoreLastNonce(ProxyRequestArgs e) + private void StoreLastNonce(HttpRequestMessage request) { - if (e.Session.HttpClient.Request.RequestUri.Query.Contains("nonce=", StringComparison.OrdinalIgnoreCase)) + if (request.RequestUri?.Query.Contains("nonce=", StringComparison.OrdinalIgnoreCase) == true) { - var queryString = HttpUtility.ParseQueryString(e.Session.HttpClient.Request.RequestUri.Query); + var queryString = HttpUtility.ParseQueryString(request.RequestUri.Query); lastNonce = queryString["nonce"]; } } @@ -177,30 +180,30 @@ private static string PadBase64(string base64) return base64 + padding; } - private static void UpdateMsalStateInHeaders(IList headers, ProxyRequestArgs e) + private static void UpdateMsalStateInHeaders(IList headers, HttpRequestMessage request) { var locationHeader = headers.FirstOrDefault(h => h.Name.Equals("Location", StringComparison.OrdinalIgnoreCase)); if (locationHeader is null || - !e.Session.HttpClient.Request.RequestUri.Query.Contains("state=", StringComparison.OrdinalIgnoreCase)) + request.RequestUri?.Query.Contains("state=", StringComparison.OrdinalIgnoreCase) != true) { return; } - var queryString = HttpUtility.ParseQueryString(e.Session.HttpClient.Request.RequestUri.Query); + var queryString = HttpUtility.ParseQueryString(request.RequestUri.Query); var msalState = queryString["state"]; locationHeader.Value = locationHeader.Value.Replace("state=@dynamic", $"state={msalState}", StringComparison.OrdinalIgnoreCase); } - private static void UpdateMsalStateInBody(ref string body, ProxyRequestArgs e, ref bool changed) + private static void UpdateMsalStateInBody(ref string body, HttpRequestMessage request, ref bool changed) { if (!body.Contains("state=@dynamic", StringComparison.OrdinalIgnoreCase) || - !e.Session.HttpClient.Request.RequestUri.Query.Contains("state=", StringComparison.OrdinalIgnoreCase)) + request.RequestUri?.Query.Contains("state=", StringComparison.OrdinalIgnoreCase) != true) { return; } - var queryString = HttpUtility.ParseQueryString(e.Session.HttpClient.Request.RequestUri.Query); + var queryString = HttpUtility.ParseQueryString(request.RequestUri.Query); var msalState = queryString["state"]; body = body.Replace("state=@dynamic", $"state={msalState}", StringComparison.OrdinalIgnoreCase); changed = true; diff --git a/DevProxy.Plugins/Mocking/GraphMockResponsePlugin.cs b/DevProxy.Plugins/Mocking/GraphMockResponsePlugin.cs index d3d0f717..6281a673 100644 --- a/DevProxy.Plugins/Mocking/GraphMockResponsePlugin.cs +++ b/DevProxy.Plugins/Mocking/GraphMockResponsePlugin.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using DevProxy.Abstractions.Models; +using DevProxy.Abstractions.Plugins; using DevProxy.Abstractions.Proxy; using DevProxy.Abstractions.Utils; using DevProxy.Plugins.Behavior; @@ -10,9 +11,9 @@ using Microsoft.Extensions.Logging; using System.Globalization; using System.Net; +using System.Text; using System.Text.Json; using System.Text.RegularExpressions; -using Titanium.Web.Proxy.Models; namespace DevProxy.Plugins.Mocking; @@ -21,40 +22,44 @@ public class GraphMockResponsePlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : MockResponsePlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, + proxyStorage) { public override string Name => nameof(GraphMockResponsePlugin); - public override async Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => async (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); - - ArgumentNullException.ThrowIfNull(e); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); if (Configuration.NoMocks) { - Logger.LogRequest("Mocks are disabled", MessageType.Skipped, new(e.Session)); - return; + Logger.LogRequest("Mocks are disabled", MessageType.Skipped, args.Request); + return PluginResponse.Continue(); } - if (!ProxyUtils.IsGraphBatchUrl(e.Session.HttpClient.Request.RequestUri)) + if (args.Request.RequestUri is null || !ProxyUtils.IsGraphBatchUrl(args.Request.RequestUri)) { // not a batch request, use the basic mock functionality - await base.BeforeRequestAsync(e, cancellationToken); - return; + return await base.OnRequestAsync!(args, cancellationToken); + } + + if (args.Request.Content is null) + { + return await base.OnRequestAsync!(args, cancellationToken); } - var batch = JsonSerializer.Deserialize(e.Session.HttpClient.Request.BodyString, ProxyUtils.JsonSerializerOptions); + var requestBody = await args.Request.Content.ReadAsStringAsync(cancellationToken); + var batch = JsonSerializer.Deserialize(requestBody, ProxyUtils.JsonSerializerOptions); if (batch is null) { - await base.BeforeRequestAsync(e, cancellationToken); - return; + return await base.OnRequestAsync!(args, cancellationToken); } var responses = new List(); @@ -64,15 +69,17 @@ public override async Task BeforeRequestAsync(ProxyRequestArgs e, CancellationTo var requestId = Guid.NewGuid().ToString(); var requestDate = DateTime.Now.ToString(CultureInfo.CurrentCulture); var headers = ProxyUtils - .BuildGraphResponseHeaders(e.Session.HttpClient.Request, requestId, requestDate); + .BuildGraphResponseHeaders(args.Request, requestId, requestDate); - if (e.SessionData.TryGetValue(nameof(RateLimitingPlugin), out var pluginData) && + // Check for rate limiting headers from RateLimitingPlugin using new storage API + var requestData = ProxyStorage.GetRequestData(args.RequestId); + if (requestData.TryGetValue(nameof(RateLimitingPlugin), out var pluginData) && pluginData is List rateLimitingHeaders) { ProxyUtils.MergeHeaders(headers, rateLimitingHeaders); } - var mockResponse = GetMatchingMockResponse(request, e.Session.HttpClient.Request.RequestUri); + var mockResponse = GetMatchingMockResponse(request, args.Request.RequestUri); if (mockResponse == null) { response = new() @@ -90,7 +97,7 @@ public override async Task BeforeRequestAsync(ProxyRequestArgs e, CancellationTo } }; - Logger.LogRequest($"502 {request.Url}", MessageType.Mocked, new(e.Session)); + Logger.LogRequest($"502 {request.Url}", MessageType.Mocked, args.Request); } else { @@ -148,7 +155,7 @@ public override async Task BeforeRequestAsync(ProxyRequestArgs e, CancellationTo Body = body }; - Logger.LogRequest($"{mockResponse.Response?.StatusCode ?? 200} {mockResponse.Request?.Url}", MessageType.Mocked, new LoggingContext(e.Session)); + Logger.LogRequest($"{mockResponse.Response?.StatusCode ?? 200} {mockResponse.Request?.Url}", MessageType.Mocked, args.Request); } responses.Add(response); @@ -156,19 +163,27 @@ public override async Task BeforeRequestAsync(ProxyRequestArgs e, CancellationTo var batchRequestId = Guid.NewGuid().ToString(); var batchRequestDate = DateTime.Now.ToString(CultureInfo.CurrentCulture); - var batchHeaders = ProxyUtils.BuildGraphResponseHeaders(e.Session.HttpClient.Request, batchRequestId, batchRequestDate); + var batchHeaders = ProxyUtils.BuildGraphResponseHeaders(args.Request, batchRequestId, batchRequestDate); var batchResponse = new GraphBatchResponsePayload { Responses = [.. responses] }; var batchResponseString = JsonSerializer.Serialize(batchResponse, ProxyUtils.JsonSerializerOptions); - ProcessMockResponse(ref batchResponseString, batchHeaders, e, null); - e.Session.GenericResponse(batchResponseString ?? string.Empty, HttpStatusCode.OK, batchHeaders.Select(h => new HttpHeader(h.Name, h.Value))); - Logger.LogRequest($"200 {e.Session.HttpClient.Request.RequestUri}", MessageType.Mocked, new(e.Session)); - e.ResponseState.HasBeenSet = true; + ProcessMockResponse(ref batchResponseString, batchHeaders, args.Request, null); - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); - } + var httpResponse = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(batchResponseString ?? string.Empty, Encoding.UTF8, "application/json") + }; + + foreach (var header in batchHeaders) + { + _ = httpResponse.Headers.TryAddWithoutValidation(header.Name, header.Value); + } + + Logger.LogRequest($"200 {args.Request.RequestUri}", MessageType.Mocked, args.Request); + return PluginResponse.Respond(httpResponse); + }; protected MockResponse? GetMatchingMockResponse(GraphBatchRequestPayloadRequest request, Uri batchRequestUri) { diff --git a/DevProxy.Plugins/Mocking/MockResponsePlugin.cs b/DevProxy.Plugins/Mocking/MockResponsePlugin.cs index 72ad2c37..87e60e11 100644 --- a/DevProxy.Plugins/Mocking/MockResponsePlugin.cs +++ b/DevProxy.Plugins/Mocking/MockResponsePlugin.cs @@ -6,7 +6,6 @@ using DevProxy.Abstractions.Plugins; using DevProxy.Abstractions.Proxy; using DevProxy.Abstractions.Utils; -using DevProxy.Plugins.Behavior; using DevProxy.Plugins.Models; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; @@ -21,8 +20,6 @@ using System.Text.Json; using System.Text.Json.Serialization; using System.Text.RegularExpressions; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; namespace DevProxy.Plugins.Mocking; @@ -44,7 +41,8 @@ public class MockResponsePlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BasePlugin( httpClient, logger, @@ -64,6 +62,8 @@ public class MockResponsePlugin( private Argument>? _httpResponseFilesArgument; private Option? _httpResponseMocksFileNameOption; + protected IProxyStorage ProxyStorage => proxyStorage; + public override string Name => nameof(MockResponsePlugin); public override async Task InitializeAsync(InitArgs e, CancellationToken cancellationToken) @@ -154,43 +154,39 @@ public override void OptionsLoaded(OptionsLoadedArgs e) ValidateMocks(); } - public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => async (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); - - ArgumentNullException.ThrowIfNull(e); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); - var request = e.Session.HttpClient.Request; - var state = e.ResponseState; if (Configuration.NoMocks) { - Logger.LogRequest("Mocks disabled", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("Mocks disabled", MessageType.Skipped, args.Request); + return PluginResponse.Continue(); } - if (!e.ShouldExecute(UrlsToWatch)) + + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); - return Task.CompletedTask; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); + return PluginResponse.Continue(); } - var matchingResponse = GetMatchingMockResponse(request); + var matchingResponse = await GetMatchingMockResponse(args.Request); if (matchingResponse is not null) { // we need to clone the response so that we're not modifying // the original that might be used in other requests var clonedResponse = (MockResponse)matchingResponse.Clone(); - ProcessMockResponseInternal(e, clonedResponse); - state.HasBeenSet = true; - return Task.CompletedTask; + var httpResponse = ProcessMockResponseInternal(args.Request, clonedResponse, args.RequestId); + return PluginResponse.Respond(httpResponse); } else if (Configuration.BlockUnmockedRequests) { - ProcessMockResponseInternal(e, new() + var errorResponse = ProcessMockResponseInternal(args.Request, new() { Request = new() { - Url = request.Url, - Method = request.Method ?? "" + Url = args.Request.RequestUri!.ToString(), + Method = args.Request.Method.Method }, Response = new() { @@ -198,25 +194,24 @@ public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken ca Body = new GraphErrorResponseBody(new() { Code = "Bad Gateway", - Message = $"No mock response found for {request.Method} {request.Url}" + Message = $"No mock response found for {args.Request.Method} {args.Request.RequestUri}" }) } - }); - state.HasBeenSet = true; - return Task.CompletedTask; + }, args.RequestId); + return PluginResponse.Respond(errorResponse); } - Logger.LogRequest("No matching mock response found", MessageType.Skipped, new(e.Session)); + Logger.LogRequest("No matching mock response found", MessageType.Skipped, args.Request); - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); - return Task.CompletedTask; - } + Logger.LogTrace("Left {Name}", nameof(OnRequestAsync)); + return PluginResponse.Continue(); + }; - protected virtual void ProcessMockResponse(ref byte[] body, IList headers, ProxyRequestArgs e, MockResponse? matchingResponse) + protected virtual void ProcessMockResponse(ref byte[] body, IList headers, HttpRequestMessage request, MockResponse? matchingResponse) { } - protected virtual void ProcessMockResponse(ref string? body, IList headers, ProxyRequestArgs e, MockResponse? matchingResponse) + protected virtual void ProcessMockResponse(ref string? body, IList headers, HttpRequestMessage request, MockResponse? matchingResponse) { if (string.IsNullOrEmpty(body)) { @@ -224,7 +219,7 @@ protected virtual void ProcessMockResponse(ref string? body, IList GetMatchingMockResponse(HttpRequestMessage request) { if (Configuration.NoMocks || Configuration.Mocks is null || @@ -290,44 +285,47 @@ Configuration.Mocks is null || return null; } - var mockResponse = Configuration.Mocks.FirstOrDefault(mockResponse => + var requestUrl = request.RequestUri!.ToString(); + var requestMethod = request.Method.Method; + + MockResponse? matchingMockResponse = null; + + foreach (var mockResponse in Configuration.Mocks) { if (mockResponse.Request is null) { - return false; + continue; } - if (mockResponse.Request.Method != request.Method) + if (mockResponse.Request.Method != requestMethod) { - return false; + continue; } - if (mockResponse.Request.Url == request.Url && - HasMatchingBody(mockResponse, request) && - IsNthRequest(mockResponse)) + var urlMatches = false; + if (mockResponse.Request.Url == requestUrl) { - return true; + urlMatches = true; } - - // check if the URL contains a wildcard - // if it doesn't, it's not a match for the current request for sure - if (!mockResponse.Request.Url.Contains('*', StringComparison.OrdinalIgnoreCase)) + else if (mockResponse.Request.Url.Contains('*', StringComparison.OrdinalIgnoreCase)) { - return false; + // turn mock URL with wildcard into a regex and match against the request URL + urlMatches = Regex.IsMatch(requestUrl, ProxyUtils.PatternToRegex(mockResponse.Request.Url)); } - // turn mock URL with wildcard into a regex and match against the request URL - return Regex.IsMatch(request.Url, ProxyUtils.PatternToRegex(mockResponse.Request.Url)) && - HasMatchingBody(mockResponse, request) && - IsNthRequest(mockResponse); - }); + if (urlMatches && await HasMatchingBody(mockResponse, request) && IsNthRequest(mockResponse)) + { + matchingMockResponse = mockResponse; + break; + } + } - if (mockResponse is not null && mockResponse.Request is not null) + if (matchingMockResponse is not null && matchingMockResponse.Request is not null) { - _ = _appliedMocks.AddOrUpdate(mockResponse.Request.Url, 1, (_, value) => ++value); + _ = _appliedMocks.AddOrUpdate(matchingMockResponse.Request.Url, 1, (_, value) => ++value); } - return mockResponse; + return matchingMockResponse; } private bool IsNthRequest(MockResponse mockResponse) @@ -344,12 +342,11 @@ private bool IsNthRequest(MockResponse mockResponse) return mockResponse.Request.Nth == nth; } - private void ProcessMockResponseInternal(ProxyRequestArgs e, MockResponse matchingResponse) + private HttpResponseMessage ProcessMockResponseInternal(HttpRequestMessage request, MockResponse matchingResponse, RequestId requestId) { string? body = null; - var requestId = Guid.NewGuid().ToString(); var requestDate = DateTime.Now.ToString(CultureInfo.CurrentCulture); - var headers = ProxyUtils.BuildGraphResponseHeaders(e.Session.HttpClient.Request, requestId, requestDate); + var headers = ProxyUtils.BuildGraphResponseHeaders(request, requestId, requestDate); var statusCode = HttpStatusCode.OK; if (matchingResponse.Response?.StatusCode is not null) { @@ -368,7 +365,9 @@ private void ProcessMockResponseInternal(ProxyRequestArgs e, MockResponse matchi headers.Add(new("content-type", "application/json")); } - if (e.SessionData.TryGetValue(nameof(RateLimitingPlugin), out var pluginData) && + // Check for rate limiting headers from RateLimitingPlugin using new storage API + var requestData = proxyStorage.GetRequestData(requestId); + if (requestData.TryGetValue(nameof(Behavior.RateLimitingPlugin), out var pluginData) && pluginData is List rateLimitingHeaders) { ProxyUtils.MergeHeaders(headers, rateLimitingHeaders); @@ -388,17 +387,21 @@ private void ProcessMockResponseInternal(ProxyRequestArgs e, MockResponse matchi var filePath = Path.Combine(Path.GetDirectoryName(Configuration.MocksFile) ?? "", ProxyUtils.ReplacePathTokens(bodyString.Trim('"')[1..])); if (!File.Exists(filePath)) { - Logger.LogError("File {FilePath} not found. Serving file path in the mock response", filePath); body = bodyString; } else { var bodyBytes = File.ReadAllBytes(filePath); - ProcessMockResponse(ref bodyBytes, headers, e, matchingResponse); - e.Session.GenericResponse(bodyBytes, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); - Logger.LogRequest($"{matchingResponse.Response.StatusCode ?? 200} {matchingResponse.Request?.Url}", MessageType.Mocked, new LoggingContext(e.Session)); - return; + ProcessMockResponse(ref bodyBytes, headers, request, matchingResponse); + var response = new HttpResponseMessage(statusCode); + foreach (var header in headers) + { + _ = response.Headers.TryAddWithoutValidation(header.Name, header.Value); + } + response.Content = new ByteArrayContent(bodyBytes); + Logger.LogRequest($"{matchingResponse.Response.StatusCode ?? 200} {matchingResponse.Request?.Url}", MessageType.Mocked, request); + return response; } } else @@ -416,10 +419,22 @@ private void ProcessMockResponseInternal(ProxyRequestArgs e, MockResponse matchi _ = headers.Remove(contentTypeHeader); } } - ProcessMockResponse(ref body, headers, e, matchingResponse); - e.Session.GenericResponse(body ?? string.Empty, statusCode, headers.Select(h => new HttpHeader(h.Name, h.Value))); - Logger.LogRequest($"{matchingResponse.Response?.StatusCode ?? 200} {matchingResponse.Request?.Url}", MessageType.Mocked, new(e.Session)); + ProcessMockResponse(ref body, headers, request, matchingResponse); + + var httpResponse = new HttpResponseMessage(statusCode); + foreach (var header in headers) + { + _ = httpResponse.Headers.TryAddWithoutValidation(header.Name, header.Value); + } + + if (!string.IsNullOrEmpty(body)) + { + httpResponse.Content = new StringContent(body, Encoding.UTF8, "application/json"); + } + + Logger.LogRequest($"{matchingResponse.Response?.StatusCode ?? 200} {matchingResponse.Request?.Url}", MessageType.Mocked, request); + return httpResponse; } private async Task GenerateMocksFromHttpResponsesAsync(ParseResult parseResult) @@ -491,9 +506,9 @@ await File.WriteAllTextAsync( Logger.LogTrace("Left {Method}", nameof(GenerateMocksFromHttpResponsesAsync)); } - private static bool HasMatchingBody(MockResponse mockResponse, Request request) + private static async Task HasMatchingBody(MockResponse mockResponse, HttpRequestMessage request) { - if (request.Method == "GET") + if (request.Method == HttpMethod.Get) { // GET requests don't have a body so we can't match on it return true; @@ -505,13 +520,21 @@ private static bool HasMatchingBody(MockResponse mockResponse, Request request) return true; } - if (!request.HasBody || string.IsNullOrEmpty(request.BodyString)) + if (request.Content is null) + { + // mock defines a body fragment but the request has no body + // so it can't match + return false; + } + + var requestBody = await request.Content.ReadAsStringAsync(); + if (string.IsNullOrEmpty(requestBody)) { // mock defines a body fragment but the request has no body // so it can't match return false; } - return request.BodyString.Contains(mockResponse.Request.BodyFragment, StringComparison.OrdinalIgnoreCase); + return requestBody.Contains(mockResponse.Request.BodyFragment, StringComparison.OrdinalIgnoreCase); } } diff --git a/DevProxy.Plugins/Mocking/OpenAIMockResponsePlugin.cs b/DevProxy.Plugins/Mocking/OpenAIMockResponsePlugin.cs index 464c5ac9..a1d1c597 100644 --- a/DevProxy.Plugins/Mocking/OpenAIMockResponsePlugin.cs +++ b/DevProxy.Plugins/Mocking/OpenAIMockResponsePlugin.cs @@ -9,7 +9,6 @@ using Microsoft.Extensions.Logging; using System.Net; using System.Text.Json; -using Titanium.Web.Proxy.Models; namespace DevProxy.Plugins.Mocking; @@ -32,76 +31,68 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell } } - public override async Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + public override Func>? OnRequestAsync => async (args, cancellationToken) => { - Logger.LogTrace("{Method} called", nameof(BeforeRequestAsync)); + Logger.LogTrace("{Method} called", nameof(OnRequestAsync)); - ArgumentNullException.ThrowIfNull(e); - - if (!e.HasRequestUrlMatch(UrlsToWatch)) - { - Logger.LogRequest("URL not matched", MessageType.Skipped, new LoggingContext(e.Session)); - return; - } - if (e.ResponseState.HasBeenSet) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) { - Logger.LogRequest("Response already set", MessageType.Skipped, new LoggingContext(e.Session)); - return; + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request); + return PluginResponse.Continue(); } - var request = e.Session.HttpClient.Request; - if (request.Method is null || - !request.Method.Equals("POST", StringComparison.OrdinalIgnoreCase) || - !request.HasBody) + if (args.Request.Method != HttpMethod.Post || args.Request.Content is null) { - Logger.LogRequest("Request is not a POST request with a body", MessageType.Skipped, new(e.Session)); - return; + Logger.LogRequest("Request is not a POST request with a body", MessageType.Skipped, args.Request); + return PluginResponse.Continue(); } - if (!TryGetOpenAIRequest(request.BodyString, out var openAiRequest)) + var requestBody = await args.Request.Content.ReadAsStringAsync(cancellationToken); + if (!TryGetOpenAIRequest(requestBody, out var openAiRequest)) { - Logger.LogRequest("Skipping non-OpenAI request", MessageType.Skipped, new(e.Session)); - return; + Logger.LogRequest("Skipping non-OpenAI request", MessageType.Skipped, args.Request); + return PluginResponse.Continue(); } if (openAiRequest is OpenAICompletionRequest completionRequest) { if ((await languageModelClient.GenerateCompletionAsync(completionRequest.Prompt, null, cancellationToken)) is not ILanguageModelCompletionResponse lmResponse) { - return; + return PluginResponse.Continue(); } if (lmResponse.ErrorMessage is not null) { Logger.LogError("Error from local language model: {Error}", lmResponse.ErrorMessage); - return; + return PluginResponse.Continue(); } var openAiResponse = lmResponse.ConvertToOpenAIResponse(); - SendMockResponse(openAiResponse, lmResponse.RequestUrl ?? string.Empty, e); + var httpResponse = CreateMockResponse(openAiResponse, lmResponse.RequestUrl ?? string.Empty, args.Request); + return PluginResponse.Respond(httpResponse); } else if (openAiRequest is OpenAIChatCompletionRequest chatRequest) { if ((await languageModelClient .GenerateChatCompletionAsync(chatRequest.Messages, null, cancellationToken)) is not ILanguageModelCompletionResponse lmResponse) { - return; + return PluginResponse.Continue(); } if (lmResponse.ErrorMessage is not null) { Logger.LogError("Error from local language model: {Error}", lmResponse.ErrorMessage); - return; + return PluginResponse.Continue(); } var openAiResponse = lmResponse.ConvertToOpenAIResponse(); - SendMockResponse(openAiResponse, lmResponse.RequestUrl ?? string.Empty, e); + var httpResponse = CreateMockResponse(openAiResponse, lmResponse.RequestUrl ?? string.Empty, args.Request); + return PluginResponse.Respond(httpResponse); } else { Logger.LogError("Unknown OpenAI request type."); + return PluginResponse.Continue(); } - - Logger.LogTrace("Left {Name}", nameof(BeforeRequestAsync)); - } + }; private bool TryGetOpenAIRequest(string content, out OpenAIRequest? request) { @@ -142,18 +133,21 @@ private bool TryGetOpenAIRequest(string content, out OpenAIRequest? request) } } - private void SendMockResponse(OpenAIResponse response, string localLmUrl, ProxyRequestArgs e) where TResponse : OpenAIResponse + private HttpResponseMessage CreateMockResponse(OpenAIResponse response, string localLmUrl, HttpRequestMessage originalRequest) where TResponse : OpenAIResponse { - e.Session.GenericResponse( - // we need this cast or else the JsonSerializer drops derived properties - JsonSerializer.Serialize((TResponse)response, ProxyUtils.JsonSerializerOptions), - HttpStatusCode.OK, - [ - new HttpHeader("content-type", "application/json"), - new HttpHeader("access-control-allow-origin", "*") - ] - ); - e.ResponseState.HasBeenSet = true; - Logger.LogRequest($"200 {localLmUrl}", MessageType.Mocked, new(e.Session)); + var httpResponse = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent( + // we need this cast or else the JsonSerializer drops derived properties + JsonSerializer.Serialize((TResponse)response, ProxyUtils.JsonSerializerOptions), + System.Text.Encoding.UTF8, + "application/json" + ) + }; + + httpResponse.Headers.Add("access-control-allow-origin", "*"); + + Logger.LogRequest($"200 {localLmUrl}", MessageType.Mocked, originalRequest); + return httpResponse; } } diff --git a/DevProxy.Plugins/Reporters/BaseReporter.cs b/DevProxy.Plugins/Reporters/BaseReporter.cs index 84fc5866..31570be6 100644 --- a/DevProxy.Plugins/Reporters/BaseReporter.cs +++ b/DevProxy.Plugins/Reporters/BaseReporter.cs @@ -11,19 +11,21 @@ namespace DevProxy.Plugins.Reporters; public abstract class BaseReporter( ILogger logger, - ISet urlsToWatch) : BasePlugin(logger, urlsToWatch) + ISet urlsToWatch, + IProxyStorage proxyStorage) : BasePlugin(logger, urlsToWatch) { public abstract string FileExtension { get; } - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); ArgumentNullException.ThrowIfNull(e); - await base.AfterRecordingStopAsync(e, cancellationToken); + //await base.AfterRecordingStopAsync(e, cancellationToken); - if (!e.GlobalData.TryGetValue(ProxyUtils.ReportsKey, out var value) || + if (!proxyStorage.GlobalData.TryGetValue(ProxyUtils.ReportsKey, out var value) || value is not Dictionary reports || reports.Count == 0) { diff --git a/DevProxy.Plugins/Reporters/JsonReporter.cs b/DevProxy.Plugins/Reporters/JsonReporter.cs index 8bc79b30..4539b158 100644 --- a/DevProxy.Plugins/Reporters/JsonReporter.cs +++ b/DevProxy.Plugins/Reporters/JsonReporter.cs @@ -12,7 +12,8 @@ namespace DevProxy.Plugins.Reporters; public class JsonReporter( ILogger logger, - ISet urlsToWatch) : BaseReporter(logger, urlsToWatch) + ISet urlsToWatch, + IProxyStorage proxyStorage) : BaseReporter(logger, urlsToWatch, proxyStorage) { private string _fileExtension = ".json"; diff --git a/DevProxy.Plugins/Reporters/MarkdownReporter.cs b/DevProxy.Plugins/Reporters/MarkdownReporter.cs index fb9e9069..6f886963 100644 --- a/DevProxy.Plugins/Reporters/MarkdownReporter.cs +++ b/DevProxy.Plugins/Reporters/MarkdownReporter.cs @@ -10,7 +10,8 @@ namespace DevProxy.Plugins.Reporters; public class MarkdownReporter( ILogger logger, - ISet urlsToWatch) : BaseReporter(logger, urlsToWatch) + ISet urlsToWatch, + IProxyStorage proxyStorage) : BaseReporter(logger, urlsToWatch, proxyStorage) { public override string Name => nameof(MarkdownReporter); public override string FileExtension => ".md"; diff --git a/DevProxy.Plugins/Reporters/PlainTextReporter.cs b/DevProxy.Plugins/Reporters/PlainTextReporter.cs index d5d51681..6c206ee9 100644 --- a/DevProxy.Plugins/Reporters/PlainTextReporter.cs +++ b/DevProxy.Plugins/Reporters/PlainTextReporter.cs @@ -10,7 +10,8 @@ namespace DevProxy.Plugins.Reporters; public class PlainTextReporter( ILogger logger, - ISet urlsToWatch) : BaseReporter(logger, urlsToWatch) + ISet urlsToWatch, + IProxyStorage proxyStorage) : BaseReporter(logger, urlsToWatch, proxyStorage) { public override string Name => nameof(PlainTextReporter); public override string FileExtension => ".txt"; diff --git a/DevProxy.Plugins/Reporting/ApiCenterMinimalPermissionsPlugin.cs b/DevProxy.Plugins/Reporting/ApiCenterMinimalPermissionsPlugin.cs index 24a69a25..e47b3d2f 100644 --- a/DevProxy.Plugins/Reporting/ApiCenterMinimalPermissionsPlugin.cs +++ b/DevProxy.Plugins/Reporting/ApiCenterMinimalPermissionsPlugin.cs @@ -29,13 +29,15 @@ public sealed class ApiCenterMinimalPermissionsPlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BaseReportingPlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, + proxyStorage) { private ApiCenterClient? _apiCenterClient; private Api[]? _apis; @@ -81,8 +83,8 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell } Logger.LogDebug("Plugin {Plugin} auth confirmed...", Name); } - - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -92,9 +94,9 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation .Where(l => l.MessageType == MessageType.InterceptedRequest && !l.Message.StartsWith("OPTIONS", StringComparison.OrdinalIgnoreCase) && - l.Context?.Session is not null && - ProxyUtils.MatchesUrlToWatch(UrlsToWatch, l.Context.Session.HttpClient.Request.RequestUri.AbsoluteUri) && - l.Context.Session.HttpClient.Request.Headers.Any(h => h.Name.Equals("authorization", StringComparison.OrdinalIgnoreCase)) + l.Request is not null && + ProxyUtils.MatchesUrlToWatch(UrlsToWatch, l.Request.RequestUri?.AbsoluteUri ?? string.Empty) && + l.Request.Headers.Authorization is not null ); if (!interceptedRequests.Any()) { @@ -199,7 +201,7 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation Errors = [.. errors] }; - StoreReport(report, e); + StoreReport(report); Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); } diff --git a/DevProxy.Plugins/Reporting/ApiCenterProductionVersionPlugin.cs b/DevProxy.Plugins/Reporting/ApiCenterProductionVersionPlugin.cs index 1d063cde..1da94509 100644 --- a/DevProxy.Plugins/Reporting/ApiCenterProductionVersionPlugin.cs +++ b/DevProxy.Plugins/Reporting/ApiCenterProductionVersionPlugin.cs @@ -28,13 +28,15 @@ public sealed class ApiCenterProductionVersionPlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BaseReportingPlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, + proxyStorage) { private ApiCenterClient? _apiCenterClient; private Api[]? _apis; @@ -80,7 +82,9 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell Logger.LogDebug("Plugin {Plugin} auth confirmed...", Name); } - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -89,8 +93,8 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation var interceptedRequests = e.RequestLogs .Where( l => l.MessageType == MessageType.InterceptedRequest && - l.Context?.Session is not null && - ProxyUtils.MatchesUrlToWatch(UrlsToWatch, l.Context.Session.HttpClient.Request.RequestUri.AbsoluteUri) + l.Request is not null && + ProxyUtils.MatchesUrlToWatch(UrlsToWatch, l.Request!.RequestUri!.AbsoluteUri) ); if (!interceptedRequests.Any()) { @@ -238,7 +242,7 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation } } - StoreReport(report, e); + StoreReport(report); Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); } diff --git a/DevProxy.Plugins/Reporting/ExecutionSummaryPlugin.cs b/DevProxy.Plugins/Reporting/ExecutionSummaryPlugin.cs index 45e42da4..18fde51d 100644 --- a/DevProxy.Plugins/Reporting/ExecutionSummaryPlugin.cs +++ b/DevProxy.Plugins/Reporting/ExecutionSummaryPlugin.cs @@ -28,13 +28,14 @@ public sealed class ExecutionSummaryPlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BaseReportingPlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, proxyStorage) { private const string _groupByOptionName = "--summary-group-by"; private const string _requestsInterceptedMessage = "Requests intercepted"; @@ -75,7 +76,8 @@ public override void OptionsLoaded(OptionsLoadedArgs e) } } - public override Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -90,8 +92,8 @@ public override Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken var interceptedRequests = e.RequestLogs .Where( l => l.MessageType == MessageType.InterceptedRequest && - l.Context?.Session is not null && - ProxyUtils.MatchesUrlToWatch(UrlsToWatch, l.Context.Session.HttpClient.Request.RequestUri.AbsoluteUri) + l.Request is not null && + ProxyUtils.MatchesUrlToWatch(UrlsToWatch, l.Request.RequestUri?.AbsoluteUri ?? string.Empty) ); ExecutionSummaryPluginReportBase report = Configuration.GroupBy switch @@ -101,7 +103,7 @@ public override Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken _ => throw new NotImplementedException() }; - StoreReport(report, e); + StoreReport(report); Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); return Task.CompletedTask; @@ -230,9 +232,9 @@ private static string GetRequestMessage(RequestLog requestLog) => private static string GetMethodAndUrl(RequestLog requestLog) { - return requestLog.Context is not null - ? $"{requestLog.Context.Session.HttpClient.Request.Method} {requestLog.Context.Session.HttpClient.Request.RequestUri}" - : "Undefined"; + return requestLog.Request is not null + ? $"{requestLog.Request.Method.Method} {requestLog.Request.RequestUri}" + : $"{requestLog.Method} {requestLog.Url}"; } #pragma warning disable IDE0072 diff --git a/DevProxy.Plugins/Reporting/GraphMinimalPermissionsGuidancePlugin.cs b/DevProxy.Plugins/Reporting/GraphMinimalPermissionsGuidancePlugin.cs index 11ce07b9..01d97015 100644 --- a/DevProxy.Plugins/Reporting/GraphMinimalPermissionsGuidancePlugin.cs +++ b/DevProxy.Plugins/Reporting/GraphMinimalPermissionsGuidancePlugin.cs @@ -41,13 +41,14 @@ public sealed class GraphMinimalPermissionsGuidancePlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BaseReportingPlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, proxyStorage) { private GraphUtils? _graphUtils; private readonly HttpClient _httpClient = httpClient; @@ -64,8 +65,8 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell InitializePermissionsToExclude(); } - - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -117,7 +118,9 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation if (ProxyUtils.IsGraphBatchUrl(uri)) { var graphVersion = ProxyUtils.IsGraphBetaUrl(uri) ? "beta" : "v1.0"; - requestsFromBatch = GetRequestsFromBatch(request.Context?.Session.HttpClient.Request.BodyString!, graphVersion, uri.Host); + var requestBody = request.Request?.Content != null ? + await request.Request.Content.ReadAsStringAsync(cancellationToken) : string.Empty; + requestsFromBatch = GetRequestsFromBatch(requestBody, graphVersion, uri.Host); } else { @@ -203,7 +206,7 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation await EvaluateMinimalScopesAsync(applicationEndpoints, rolesToEvaluate, GraphPermissionsType.Application, applicationPermissionsInfo, cancellationToken); } - StoreReport(report, e); + StoreReport(report); Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); } @@ -332,13 +335,15 @@ private static (string method, string url)[] GetRequestsFromBatch(string batchBo /// private static (GraphPermissionsType type, IEnumerable permissions) GetPermissionsAndType(RequestLog request) { - var authHeader = request.Context?.Session.HttpClient.Request.Headers.GetFirstHeader("Authorization"); - if (authHeader == null) + // Try to get Authorization header from the new Request property first + var authHeaderValue = request.Request?.Headers.Authorization?.ToString(); + + if (string.IsNullOrEmpty(authHeaderValue)) { return (GraphPermissionsType.Application, []); } - var token = authHeader.Value.Replace("Bearer ", string.Empty, StringComparison.OrdinalIgnoreCase); + var token = authHeaderValue.Replace("Bearer ", string.Empty, StringComparison.OrdinalIgnoreCase); var tokenChunks = token.Split('.'); if (tokenChunks.Length != 3) { diff --git a/DevProxy.Plugins/Reporting/GraphMinimalPermissionsPlugin.cs b/DevProxy.Plugins/Reporting/GraphMinimalPermissionsPlugin.cs index 24b3aac1..73f1e586 100644 --- a/DevProxy.Plugins/Reporting/GraphMinimalPermissionsPlugin.cs +++ b/DevProxy.Plugins/Reporting/GraphMinimalPermissionsPlugin.cs @@ -26,13 +26,15 @@ public sealed class GraphMinimalPermissionsPlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BaseReportingPlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, + proxyStorage) { private GraphUtils? _graphUtils; private readonly HttpClient _httpClient = httpClient; @@ -47,8 +49,8 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell _graphUtils = ActivatorUtilities.CreateInstance(e.ServiceProvider); } - - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -92,7 +94,8 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation if (ProxyUtils.IsGraphBatchUrl(uri)) { var graphVersion = ProxyUtils.IsGraphBetaUrl(uri) ? "beta" : "v1.0"; - var requestsFromBatch = GetRequestsFromBatch(request.Context?.Session.HttpClient.Request.BodyString!, graphVersion, uri.Host); + var bodyString = await request.Request!.Content!.ReadAsStringAsync(cancellationToken); + var requestsFromBatch = GetRequestsFromBatch(bodyString, graphVersion, uri.Host); endpoints.AddRange(requestsFromBatch); } else @@ -118,7 +121,7 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation var report = await DetermineMinimalScopesAsync(endpoints, cancellationToken); if (report is not null) { - StoreReport(report, e); + StoreReport(report); } Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); diff --git a/DevProxy.Plugins/Reporting/MinimalCsomPermissionsPlugin.cs b/DevProxy.Plugins/Reporting/MinimalCsomPermissionsPlugin.cs index 7dfd89c2..5add2a0d 100644 --- a/DevProxy.Plugins/Reporting/MinimalCsomPermissionsPlugin.cs +++ b/DevProxy.Plugins/Reporting/MinimalCsomPermissionsPlugin.cs @@ -24,13 +24,15 @@ public sealed class MinimalCsomPermissionsPlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BaseReportingPlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, + proxyStorage) { private CsomTypesDefinitionLoader? _loader; @@ -60,8 +62,8 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell Logger.LogTrace("Left MinimalCsomPermissionsPlugin.RegisterAsync"); } - - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -72,8 +74,9 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation var interceptedRequests = e.RequestLogs .Where(l => l.MessageType == MessageType.InterceptedRequest && - l.Message.StartsWith("POST", StringComparison.OrdinalIgnoreCase) && - l.Message.Contains("/_vti_bin/client.svc/ProcessQuery", StringComparison.InvariantCultureIgnoreCase) + l.Request is not null && + l.Request.RequestUri!.AbsoluteUri.Contains("/_vti_bin/client.svc/ProcessQuery", StringComparison.InvariantCultureIgnoreCase) && + l.Request.Content is not null ); if (!interceptedRequests.Any()) { @@ -90,18 +93,13 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation { cancellationToken.ThrowIfCancellationRequested(); - if (request.Context == null) - { - continue; - } - - if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, request.Context.Session.HttpClient.Request.RequestUri.AbsoluteUri)) + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, request.Request!.RequestUri!.AbsoluteUri)) { - Logger.LogDebug("URL not matched: {Url}", request.Context.Session.HttpClient.Request.RequestUri.AbsoluteUri); + Logger.LogDebug("URL not matched: {Url}", request.Request.RequestUri.AbsoluteUri); continue; } - var requestBody = await request.Context.Session.GetRequestBodyAsString(cancellationToken); + var requestBody = await request.Request.Content!.ReadAsStringAsync(cancellationToken); if (string.IsNullOrEmpty(requestBody)) { continue; @@ -161,7 +159,7 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation Errors = [.. errors] }; - StoreReport(report, e); + StoreReport(report); Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); } diff --git a/DevProxy.Plugins/Reporting/MinimalPermissionsGuidancePlugin.cs b/DevProxy.Plugins/Reporting/MinimalPermissionsGuidancePlugin.cs index 59d9ff9b..d3c659d5 100644 --- a/DevProxy.Plugins/Reporting/MinimalPermissionsGuidancePlugin.cs +++ b/DevProxy.Plugins/Reporting/MinimalPermissionsGuidancePlugin.cs @@ -40,13 +40,15 @@ public sealed class MinimalPermissionsGuidancePlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigSection) : + IConfigurationSection pluginConfigSection, + IProxyStorage proxyStorage) : BaseReportingPlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigSection) + pluginConfigSection, + proxyStorage) { private Dictionary? _apiSpecsByUrl; @@ -69,7 +71,8 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell } } - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -78,10 +81,10 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation var interceptedRequests = e.RequestLogs .Where(l => l.MessageType == MessageType.InterceptedRequest && - !l.Message.StartsWith("OPTIONS", StringComparison.OrdinalIgnoreCase) && - l.Context?.Session is not null && - ProxyUtils.MatchesUrlToWatch(UrlsToWatch, l.Context.Session.HttpClient.Request.RequestUri.AbsoluteUri) && - l.Context.Session.HttpClient.Request.Headers.Any(h => h.Name.Equals("authorization", StringComparison.OrdinalIgnoreCase)) + l.Request is not null && + l.Request.Method != HttpMethod.Options && + ProxyUtils.MatchesUrlToWatch(UrlsToWatch, l.Request!.RequestUri!.AbsoluteUri) && + l.Request.Headers.Authorization is not null ); if (!interceptedRequests.Any()) { @@ -176,7 +179,7 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation Errors = [.. errors] }; - StoreReport(report, e); + StoreReport(report); Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); } diff --git a/DevProxy.Plugins/Reporting/MinimalPermissionsPlugin.cs b/DevProxy.Plugins/Reporting/MinimalPermissionsPlugin.cs index d4134458..b072014a 100644 --- a/DevProxy.Plugins/Reporting/MinimalPermissionsPlugin.cs +++ b/DevProxy.Plugins/Reporting/MinimalPermissionsPlugin.cs @@ -24,13 +24,15 @@ public sealed class MinimalPermissionsPlugin( ILogger logger, ISet urlsToWatch, IProxyConfiguration proxyConfiguration, - IConfigurationSection pluginConfigurationSection) : + IConfigurationSection pluginConfigurationSection, + IProxyStorage proxyStorage) : BaseReportingPlugin( httpClient, logger, urlsToWatch, proxyConfiguration, - pluginConfigurationSection) + pluginConfigurationSection, + proxyStorage) { private Dictionary? _apiSpecsByUrl; @@ -54,7 +56,8 @@ public override async Task InitializeAsync(InitArgs e, CancellationToken cancell } } - public override async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public async Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -63,9 +66,9 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation var interceptedRequests = e.RequestLogs .Where(l => l.MessageType == MessageType.InterceptedRequest && - !l.Message.StartsWith("OPTIONS", StringComparison.OrdinalIgnoreCase) && - l.Context?.Session is not null && - ProxyUtils.MatchesUrlToWatch(UrlsToWatch, l.Context.Session.HttpClient.Request.RequestUri.AbsoluteUri) + l.Request is not null && + l.Request.Method != HttpMethod.Options && + ProxyUtils.MatchesUrlToWatch(UrlsToWatch, l.Request!.RequestUri!.AbsoluteUri) ); if (!interceptedRequests.Any()) { @@ -142,7 +145,7 @@ public override async Task AfterRecordingStopAsync(RecordingArgs e, Cancellation Errors = [.. errors] }; - StoreReport(report, e); + StoreReport(report); Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); } diff --git a/DevProxy.Plugins/Reporting/UrlDiscoveryPlugin.cs b/DevProxy.Plugins/Reporting/UrlDiscoveryPlugin.cs index 7d70e1f5..395da1f9 100644 --- a/DevProxy.Plugins/Reporting/UrlDiscoveryPlugin.cs +++ b/DevProxy.Plugins/Reporting/UrlDiscoveryPlugin.cs @@ -11,11 +11,12 @@ namespace DevProxy.Plugins.Reporting; public sealed class UrlDiscoveryPlugin( ILogger logger, - ISet urlsToWatch) : BaseReportingPlugin(logger, urlsToWatch) + ISet urlsToWatch, IProxyStorage proxyStorage) : BaseReportingPlugin(logger, urlsToWatch, proxyStorage) { public override string Name => nameof(UrlDiscoveryPlugin); - public override Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) + public override Func? HandleRecordingStopAsync => AfterRecordingStopAsync; + public Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken cancellationToken) { Logger.LogTrace("{Method} called", nameof(AfterRecordingStopAsync)); @@ -28,19 +29,18 @@ public override Task AfterRecordingStopAsync(RecordingArgs e, CancellationToken } var requestLogs = e.RequestLogs - .Where(l => ProxyUtils.MatchesUrlToWatch(UrlsToWatch, l.Context?.Session.HttpClient.Request.RequestUri.AbsoluteUri ?? "")); + .Where(l => ProxyUtils.MatchesUrlToWatch(UrlsToWatch, l.Request!.RequestUri!.AbsoluteUri ?? "")); UrlDiscoveryPluginReport report = new() { Data = [ .. requestLogs - .Where(log => log.Context is not null) - .Select(log => log.Context!.Session.HttpClient.Request.RequestUri.ToString()).Distinct().Order() + .Select(log => log.Request!.RequestUri!.AbsoluteUri).Distinct().Order() ] }; - StoreReport(report, e); + StoreReport(report); Logger.LogTrace("Left {Name}", nameof(AfterRecordingStopAsync)); return Task.CompletedTask; diff --git a/DevProxy.Plugins/Utils/GraphUtils.cs b/DevProxy.Plugins/Utils/GraphUtils.cs index 27ef95b6..70843438 100644 --- a/DevProxy.Plugins/Utils/GraphUtils.cs +++ b/DevProxy.Plugins/Utils/GraphUtils.cs @@ -5,7 +5,6 @@ using DevProxy.Plugins.Models; using Microsoft.Extensions.Logging; using System.Net.Http.Json; -using Titanium.Web.Proxy.Http; namespace DevProxy.Plugins.Utils; @@ -17,7 +16,7 @@ sealed class GraphUtils( private readonly ILogger _logger = logger; // throttle requests per workload - public static string BuildThrottleKey(Request r) => BuildThrottleKey(r.RequestUri); + public static string BuildThrottleKey(HttpRequestMessage r) => BuildThrottleKey(r.RequestUri!); public static string BuildThrottleKey(Uri uri) { diff --git a/DevProxy.Plugins/Utils/MinimalPermissionsUtils.cs b/DevProxy.Plugins/Utils/MinimalPermissionsUtils.cs index a89523a9..77a5b82d 100644 --- a/DevProxy.Plugins/Utils/MinimalPermissionsUtils.cs +++ b/DevProxy.Plugins/Utils/MinimalPermissionsUtils.cs @@ -25,7 +25,9 @@ public static string[] GetScopesFromToken(string? jwtToken, ILogger logger) try { - var token = jwtToken.Split(' ')[1]; + var token = jwtToken!.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase) + ? jwtToken["Bearer ".Length..].Trim() + : jwtToken.Trim(); var handler = new JwtSecurityTokenHandler(); var jsonToken = handler.ReadToken(token) as JwtSecurityToken; var scopes = jsonToken?.Claims diff --git a/DevProxy.Plugins/inventory.md b/DevProxy.Plugins/inventory.md new file mode 100644 index 00000000..66faccde --- /dev/null +++ b/DevProxy.Plugins/inventory.md @@ -0,0 +1,384 @@ +# DevProxy Plugins Inventory + +This document provides an inventory of all plugins in the DevProxy.Plugins project and their implemented methods. This inventory was created to assist with the migration from the old event-based API to the new functional API. + +## ApiCenterMinimalPermissionsPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync +- AfterRecordingStopAsync + +**Behavior:** Read-only (collects API permission data for reporting) + +## ApiCenterOnboardingPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Read-only (collects API metadata for reporting) + +## ApiCenterProductionVersionPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Read-only (collects API version information for reporting) + +## AuthPlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Modifies responses (returns 401/403 for unauthorized requests) + +## CachingGuidancePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- BeforeRequestAsync + +**Behavior:** Read-only (analyzes request patterns and provides caching guidance) + +## CrudApiPlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- GetOptions +- OptionsLoaded +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Modifies responses (creates mock CRUD API responses) + +## DevToolsPlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- GetOptions +- OptionsLoaded +- InitializeAsync +- BeforeRequestAsync +- BeforeResponseAsync +- AfterResponseAsync +- AfterRequestLogAsync + +**Behavior:** Read-only (captures request/response data for developer tools) + +## EntraMockResponsePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Modifies responses (provides mock Entra ID responses) + +## ExecutionSummaryPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- GetOptions +- OptionsLoaded +- InitializeAsync +- BeforeRequestAsync +- AfterRecordingStopAsync + +**Behavior:** Read-only (collects execution statistics for reporting) + +## GenericRandomErrorPlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- GetOptions +- OptionsLoaded +- InitializeAsync +- OnRequestAsync (NEW API - MIGRATED) + +**Behavior:** Modifies responses (generates random error responses) + +## GraphBetaSupportGuidancePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- OnRequestLogAsync (NEW API - MIGRATED) + +**Behavior:** Read-only (provides guidance about beta API usage) + +## GraphClientRequestIdGuidancePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- OnRequestLogAsync (NEW API - MIGRATED) + +**Behavior:** Read-only (provides guidance about request ID headers) + +## GraphConnectorGuidancePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- OnRequestLogAsync (NEW API - MIGRATED) + +**Behavior:** Read-only (provides guidance about Graph connector usage) + +## GraphMinimalPermissionsGuidancePlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync +- AfterRecordingStopAsync + +**Behavior:** Read-only (analyzes and reports on Graph API permissions) + +## GraphMinimalPermissionsPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Read-only (collects Graph API permission data for reporting) + +## GraphMockResponsePlugin + +**Base Class:** MockResponsePlugin +**Methods Implemented:** +- BeforeRequestAsync + +**Behavior:** Modifies responses (provides mock Graph API responses, including batch processing) + +## GraphRandomErrorPlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- GetOptions +- OptionsLoaded +- OnRequestAsync (NEW API - MIGRATED) + +**Behavior:** Modifies responses (generates random Graph API error responses) + +## GraphSdkGuidancePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- OnResponseLogAsync (NEW API - MIGRATED) + +**Behavior:** Read-only (provides guidance about using Graph SDKs) + +## GraphSelectGuidancePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- AfterResponseAsync + +**Behavior:** Read-only (provides guidance about Graph $select optimization) + +## HttpFileGeneratorPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync +- AfterRecordingStopAsync + +**Behavior:** Read-only (generates HTTP files from recorded requests) + +## LanguageModelFailurePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Modifies responses (simulates AI/ML service failures) + +## LanguageModelRateLimitingPlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync +- BeforeResponseAsync + +**Behavior:** Modifies responses (enforces token-based rate limiting for AI services) + +## LatencyPlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- BeforeRequestAsync + +**Behavior:** Read-only (adds artificial delay to requests, doesn't modify responses) + +## MinimalCsomPermissionsPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Read-only (analyzes SharePoint CSOM permissions for reporting) + +## MinimalPermissionsGuidancePlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Read-only (provides permission optimization guidance) + +## MinimalPermissionsPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Read-only (collects API permission data for reporting) + +## MockGeneratorPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- AfterRecordingStopAsync + +**Behavior:** Read-only (generates mock response files from recorded requests) + +## MockRequestPlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- GetOptions +- OptionsLoaded +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Modifies responses (provides mock request/response functionality) + +## MockResponsePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- GetOptions +- OptionsLoaded +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Modifies responses (provides comprehensive mock response functionality) + +## ODSPSearchGuidancePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- OnRequestLogAsync (NEW API - MIGRATED) + +**Behavior:** Read-only (provides guidance about SharePoint search optimization) + +## ODataPagingGuidancePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- OnRequestLogAsync (NEW API - MIGRATED) +- OnResponseLogAsync (NEW API - MIGRATED) + +**Behavior:** Read-only (provides guidance about OData paging patterns) + +## OpenAIMockResponsePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Modifies responses (provides mock OpenAI API responses using local language models) + +## OpenAITelemetryPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync +- AfterResponseAsync +- AfterRecordingStopAsync + +**Behavior:** Read-only (collects OpenAI API usage telemetry for reporting) + +## OpenApiSpecGeneratorPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync +- AfterRecordingStopAsync + +**Behavior:** Read-only (generates OpenAPI specifications from recorded requests) + +## RateLimitingPlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- GetOptions +- OptionsLoaded +- InitializeAsync +- BeforeRequestAsync +- BeforeResponseAsync + +**Behavior:** Modifies responses (enforces rate limits and adds rate limit headers) + +## RetryAfterPlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- BeforeRequestAsync + +**Behavior:** Modifies responses (throttles requests that don't respect Retry-After headers) + +## RewritePlugin + +**Base Class:** BasePlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync + +**Behavior:** Modifies requests (rewrites request URLs, doesn't modify responses directly) + +## TypeSpecGeneratorPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- InitializeAsync +- BeforeRequestAsync +- AfterRecordingStopAsync + +**Behavior:** Read-only (generates TypeSpec definitions from recorded requests) + +## UrlDiscoveryPlugin + +**Base Class:** BaseReportingPlugin +**Methods Implemented:** +- BeforeRequestAsync + +**Behavior:** Read-only (discovers and reports API URLs) + +--- + +## Summary + +- **Total Plugins:** 38 +- **Plugins using BeforeRequestAsync:** 26 (decreased by 4 due to ODataPagingGuidancePlugin, ODSPSearchGuidancePlugin, and GenericRandomErrorPlugin migrations) +- **Plugins using BeforeResponseAsync:** 2 (DevToolsPlugin, RateLimitingPlugin) +- **Plugins using AfterResponseAsync:** 5 (decreased by 2 due to ODataPagingGuidancePlugin and GraphSdkGuidancePlugin migrations: GraphSelectGuidancePlugin, OpenAITelemetryPlugin) +- **Plugins using AfterRequestLogAsync:** 1 (DevToolsPlugin) +- **Plugins using AfterRecordingStopAsync:** 8 (ApiCenterMinimalPermissionsPlugin, ExecutionSummaryPlugin, GraphMinimalPermissionsGuidancePlugin, HttpFileGeneratorPlugin, MockGeneratorPlugin, OpenAITelemetryPlugin, OpenApiSpecGeneratorPlugin, TypeSpecGeneratorPlugin) +- **Plugins using OnRequestAsync (NEW API):** 2 (GraphRandomErrorPlugin, GenericRandomErrorPlugin - migrated) +- **Plugins using OnRequestLogAsync (NEW API):** 6 (GraphBetaSupportGuidancePlugin, CachingGuidancePlugin, GraphClientRequestIdGuidancePlugin, GraphConnectorGuidancePlugin, ODataPagingGuidancePlugin, ODSPSearchGuidancePlugin - migrated) +- **Plugins using OnResponseLogAsync (NEW API):** 2 (ODataPagingGuidancePlugin, GraphSdkGuidancePlugin - migrated) \ No newline at end of file diff --git a/DevProxy.Plugins/migration.md b/DevProxy.Plugins/migration.md new file mode 100644 index 00000000..50f3a46a --- /dev/null +++ b/DevProxy.Plugins/migration.md @@ -0,0 +1,748 @@ +# Plugin Migration Guide: From Event-Based to Functional API + +This document provides detailed guidance on migrating DevProxy plugins from the old event-based API to the new functional API pattern. + +## Overview of API Changes + +The DevProxy plugin architecture is transitioning from an event-based model to a functional model for better control flow and testability. + +### Old API (Event-Based) +```csharp +public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) +{ + // Logic to decide whether to intercept + if (!ShouldIntercept(e)) + { + return Task.CompletedTask; + } + + // Modify response directly through session + e.Session.GenericResponse(body, statusCode, headers); + e.ResponseState.HasBeenSet = true; + + return Task.CompletedTask; +} + +public override Task BeforeResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) +{ + // Process response before it's sent + return Task.CompletedTask; +} + +public override Task AfterResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) +{ + // Process response after it's sent (read-only) + return Task.CompletedTask; +} +``` + +### New API (Functional) +```csharp +// For plugins that need to modify requests or responses +public override Func>? OnRequestAsync => async (args, cancellationToken) => +{ + // Logic to decide whether to intercept + if (!ShouldIntercept(args.Request)) + { + return PluginResponse.Continue(); + } + + // Create and return response + var response = new HttpResponseMessage(HttpStatusCode.BadRequest) + { + Content = new StringContent(body) + }; + + return PluginResponse.Respond(response); +}; + +// For guidance plugins that only need to log or analyze requests +public override Func? OnRequestLogAsync => async (args, cancellationToken) => +{ + // Analyze request and provide guidance + if (ShouldProvideGuidance(args.Request)) + { + Logger.LogRequest("Guidance message", MessageType.Tip, args.Request); + } +}; + +// For plugins that need to modify responses from remote server +public override Func>? OnResponseAsync => async (args, cancellationToken) => +{ + // Process response and optionally modify it + // Return null to continue, or PluginResponse to modify + return null; +}; + +// For guidance plugins that only need to log or analyze responses +public override Func? OnResponseLogAsync => async (args, cancellationToken) => +{ + // Analyze response and provide guidance + if (ShouldProvideGuidance(args.HttpResponseMessage)) + { + Logger.LogRequest("Response guidance message", MessageType.Tip, args.HttpRequestMessage, args.RequestId); + } +}; +``` + +## Key Differences + +### 1. Input Arguments +- **Old API:** `ProxyRequestArgs e` and `ProxyResponseArgs e` containing session, response state, and global data +- **New API:** + - `RequestArguments args` containing `HttpRequestMessage` and `RequestId` + - `ResponseArguments args` containing `HttpRequestMessage`, `HttpResponseMessage` and `RequestId` + +### 2. Return Values +- **Old API:** `Task` (void) - side effects through `e.Session` and `e.ResponseState` +- **New API:** + - `Task` for `OnRequestAsync` - explicit return values to control flow + - `Task` for `OnRequestLogAsync` - read-only logging/analysis of requests + - `Task` for `OnResponseAsync` - response modification (return null to continue) + - `Task` for `OnResponseLogAsync` - read-only logging/analysis of responses + +### 3. Response Creation +- **Old API:** Direct manipulation of session: `e.Session.GenericResponse(...)` +- **New API:** Create and return `HttpResponseMessage`: `PluginResponse.Respond(response)` + +### 4. Flow Control +- **Old API:** Check `e.ResponseState.HasBeenSet` and set it to `true` +- **New API:** Return `PluginResponse.Continue()` or `PluginResponse.Respond(response)` + +### 5. Method Selection Guide +Choose the appropriate new API method based on your plugin's behavior: + +- **`OnRequestAsync`**: Use for plugins that need to intercept and potentially modify or respond to requests +- **`OnRequestLogAsync`**: Use for guidance plugins that only need to analyze requests and provide logging/guidance (cannot modify requests or responses) +- **`OnResponseAsync`**: Use for plugins that need to modify responses from the remote server +- **`OnResponseLogAsync`**: Use for guidance plugins that only need to analyze responses and provide logging/guidance (cannot modify responses) + +## Migration Steps + +### Step 1: Determine the Appropriate New Method + +**Response Modifying Plugins** → Use `OnRequestAsync`: +- MockResponsePlugin +- AuthPlugin +- RateLimitingPlugin +- GenericRandomErrorPlugin +- etc. + +**Guidance/Analysis Plugins** → Use `OnRequestLogAsync` or `OnResponseLogAsync`: +- CachingGuidancePlugin → `OnRequestLogAsync` +- GraphSdkGuidancePlugin → `OnResponseLogAsync` (analyzes responses from AfterResponseAsync) +- UrlDiscoveryPlugin → `OnRequestLogAsync` +- Most reporting plugins → `OnRequestLogAsync` + +**Response Modifying Plugins** → Use `OnResponseAsync`: +- Plugins that need to modify responses from the remote server + +### Step 2: Change Method Signature + +**For Response Modifying Plugins (OnRequestAsync):** +```csharp +// Before +public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + +// After +public override Func>? OnRequestAsync => async (args, cancellationToken) => +``` + +**For Request Guidance Plugins (OnRequestLogAsync):** +```csharp +// Before +public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + +// After +public override Func? OnRequestLogAsync => async (args, cancellationToken) => +``` + +**For Response Guidance Plugins (OnResponseLogAsync):** +```csharp +// Before +public override Task AfterResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) + +// After +public override Func? OnResponseLogAsync => async (args, cancellationToken) => +``` + +**For Response Modifying Plugins (OnResponseAsync):** +```csharp +// Before +public override Task BeforeResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) + +// After +public override Func>? OnResponseAsync => async (args, cancellationToken) => +``` + +### Step 3: Update Input Data Access + +**Before (Request):** +```csharp +var request = e.Session.HttpClient.Request; +var url = request.RequestUri; +var method = request.Method; +var body = request.BodyString; +``` + +**After (Request):** +```csharp +var request = args.Request; +var url = request.RequestUri; +var method = request.Method.Method; +var body = await request.Content.ReadAsStringAsync(); +``` + +**Before (Response):** +```csharp +var request = e.Session.HttpClient.Request; +var response = e.Session.HttpClient.Response; +var statusCode = response.StatusCode; +var responseBody = response.BodyString; +``` + +**After (Response):** +```csharp +var request = args.Request; +var response = args.Response; +var statusCode = response.StatusCode; +var responseBody = await response.Content.ReadAsStringAsync(); +``` + +### Step 4: Update URL Matching Logic + +**Before:** +```csharp +if (!e.HasRequestUrlMatch(UrlsToWatch)) +{ + Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + return Task.CompletedTask; +} +``` + +**After (OnRequestAsync):** +```csharp +if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) +{ + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request, args.RequestId); + return PluginResponse.Continue(); +} +``` + +**After (OnRequestLogAsync/OnResponseLogAsync):** +```csharp +if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) +{ + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request, args.RequestId); + return; +} +``` + +### Step 5: Update Response State Checking + +**Before:** +```csharp +if (e.ResponseState.HasBeenSet) +{ + Logger.LogRequest("Response already set", MessageType.Skipped, new(e.Session)); + return Task.CompletedTask; +} +``` + +**After:** +```csharp +// Not needed in new API - flow control is handled by return values +// OnRequestAsync: Each plugin returns either Continue() or Respond() +// OnRequestLogAsync/OnResponseLogAsync: Cannot modify responses, so this check is irrelevant +// OnResponseAsync: Return null to continue, or PluginResponse to modify +``` + +### Step 6: Update Response Creation (OnRequestAsync and OnResponseAsync only) + +**Before:** +```csharp +var headers = new List +{ + new("Content-Type", "application/json"), + new("X-Custom", "value") +}; + +e.Session.GenericResponse(jsonBody, HttpStatusCode.BadRequest, headers); +e.ResponseState.HasBeenSet = true; +``` + +**After:** +```csharp +var response = new HttpResponseMessage(HttpStatusCode.BadRequest) +{ + Content = new StringContent(jsonBody, Encoding.UTF8, "application/json") +}; +response.Headers.Add("X-Custom", "value"); + +return PluginResponse.Respond(response); +``` + +### Step 7: Update Passthrough Logic + +**Before:** +```csharp +if (shouldPassThrough) +{ + Logger.LogRequest("Pass through", MessageType.Skipped, new(e.Session)); + return Task.CompletedTask; +} +``` + +**After (OnRequestAsync/OnResponseAsync):** +```csharp +if (shouldPassThrough) +{ + Logger.LogRequest("Pass through", MessageType.Skipped, args.Request, args.RequestId); // or args.HttpRequestMessage + return PluginResponse.Continue(); // or return null for OnResponseAsync +} +``` + +**After (OnRequestLogAsync/OnResponseLogAsync):** +```csharp +if (shouldSkip) +{ + Logger.LogRequest("Skipping analysis", MessageType.Skipped, args.Request, args.RequestId); // or args.HttpRequestMessage + return; +} +``` + +## Complete Migration Examples + +### Example 1: Response Modifying Plugin (OnRequestAsync) + +**Before (Old API):** +```csharp +public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) +{ + if (!e.HasRequestUrlMatch(UrlsToWatch)) + { + return Task.CompletedTask; + } + + if (e.ResponseState.HasBeenSet) + { + return Task.CompletedTask; + } + + if (ShouldFail()) + { + var error = GetRandomError(); + var body = JsonSerializer.Serialize(error.Body); + var headers = error.Headers.Select(h => new HttpHeader(h.Name, h.Value)); + + e.Session.GenericResponse(body, (HttpStatusCode)error.StatusCode, headers); + e.ResponseState.HasBeenSet = true; + } + + return Task.CompletedTask; +} +``` + +**After (New API):** +```csharp +public override Func>? OnRequestAsync => (args, cancellationToken) => +{ + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) + { + return Task.FromResult(PluginResponse.Continue()); + } + + if (!ShouldFail()) + { + return Task.FromResult(PluginResponse.Continue()); + } + + var error = GetRandomError(); + var response = new HttpResponseMessage((HttpStatusCode)error.StatusCode) + { + Content = new StringContent( + JsonSerializer.Serialize(error.Body), + Encoding.UTF8, + "application/json" + ) + }; + + foreach (var header in error.Headers) + { + response.Headers.Add(header.Name, header.Value); + } + + return Task.FromResult(PluginResponse.Respond(response)); +}; +``` + +### Example 2: Request Guidance Plugin (OnRequestLogAsync) + +**Before (Old API):** +```csharp +public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) +{ + if (!e.HasRequestUrlMatch(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + return Task.CompletedTask; + } + + var request = e.Session.HttpClient.Request; + if (ShouldProvideGuidance(request)) + { + Logger.LogRequest("Consider using cache for better performance", MessageType.Tip, new(e.Session)); + } + + return Task.CompletedTask; +} +``` + +**After (New API):** +```csharp +public override Func? OnRequestLogAsync => (args, cancellationToken) => +{ + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request, args.RequestId); + return Task.CompletedTask; + } + + if (ShouldProvideGuidance(args.Request)) + { + Logger.LogRequest("Consider using cache for better performance", MessageType.Tip, args.Request, args.RequestId); + } + + return Task.CompletedTask; +}; +``` + +### Example 3: Response Guidance Plugin (OnResponseLogAsync) + +**Before (Old API):** +```csharp +public override Task AfterResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) +{ + if (!e.HasRequestUrlMatch(UrlsToWatch)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, new(e.Session)); + return Task.CompletedTask; + } + + var response = e.Session.HttpClient.Response; + if (ShouldProvideGuidance(response)) + { + Logger.LogRequest("Consider optimizing your API queries", MessageType.Tip, new(e.Session)); + } + + return Task.CompletedTask; +} +``` + +**After (New API):** +```csharp +public override Func? OnResponseLogAsync => (args, cancellationToken) => +{ + if (!ProxyUtils.MatchesUrlToWatch(UrlsToWatch, args.Request.RequestUri)) + { + Logger.LogRequest("URL not matched", MessageType.Skipped, args.Request, args.RequestId); + return Task.CompletedTask; + } + + if (ShouldProvideGuidance(args.HttpResponseMessage)) + { + Logger.LogRequest("Consider optimizing your API queries", MessageType.Tip, args.Request, args.RequestId); + } + + return Task.CompletedTask; +}; +``` + +### Example 4: Plugin with Storage Requirements + +**Before (Old API):** +```csharp +public sealed class MyStoragePlugin( + ILogger logger, + ISet urlsToWatch) : BasePlugin(logger, urlsToWatch) +{ + public override Task BeforeRequestAsync(ProxyRequestArgs e, CancellationToken cancellationToken) + { + // Access global data + e.GlobalData["RequestCount"] = (int)(e.GlobalData.GetValueOrDefault("RequestCount", 0)) + 1; + + // Access session data + e.SessionData["RequestStartTime"] = DateTime.UtcNow; + + return Task.CompletedTask; + } + + public override Task AfterResponseAsync(ProxyResponseArgs e, CancellationToken cancellationToken) + { + // Use session data + if (e.SessionData.TryGetValue("RequestStartTime", out var startTime)) + { + var duration = DateTime.UtcNow - (DateTime)startTime; + Logger.LogInformation("Request took {Duration}ms", duration.TotalMilliseconds); + } + + return Task.CompletedTask; + } +} +``` + +**After (New API):** +```csharp +public sealed class MyStoragePlugin( + ILogger logger, + ISet urlsToWatch, + IProxyStorage proxyStorage) : BasePlugin(logger, urlsToWatch) +{ + private readonly IProxyStorage _proxyStorage = proxyStorage; + + public override Func? OnRequestLogAsync => (args, cancellationToken) => + { + // Access global data + _proxyStorage.GlobalData["RequestCount"] = (int)(_proxyStorage.GlobalData.GetValueOrDefault("RequestCount", 0)) + 1; + + // Access request-specific data + var requestData = _proxyStorage.GetRequestData(args.RequestId); + requestData["RequestStartTime"] = DateTime.UtcNow; + + return Task.CompletedTask; + }; + + public override Func? OnResponseLogAsync => (args, cancellationToken) => + { + // Use request-specific data + var requestData = _proxyStorage.GetRequestData(args.RequestId); + if (requestData.TryGetValue("RequestStartTime", out var startTime)) + { + var duration = DateTime.UtcNow - (DateTime)startTime; + Logger.LogInformation("Request took {Duration}ms", duration.TotalMilliseconds); + } + + return Task.CompletedTask; + }; +} +``` + +## Important Notes + +### 1. Logging Context +The logging context changes from `LoggingContext(e.Session)` to the appropriate request message: +```csharp +// Old +Logger.LogRequest("Message", MessageType.Info, new LoggingContext(e.Session)); + +// New (Request-based methods) +Logger.LogRequest("Message", MessageType.Info, args.Request, args.RequestId); + +// New (Response-based methods) +Logger.LogRequest("Message", MessageType.Info, args.Request, args.RequestId, args.Response); +``` + +### 2. Global Data and Session Data +Global data and session data access patterns will need to be reviewed as they may not be available in the new API. These features are now handled through dependency injection using the `IProxyStorage` interface. + +**For plugins that need global or request-specific storage:** + +Use constructor injection to access the `IProxyStorage` interface: + +```csharp +public sealed class MyPlugin( + ILogger logger, + ISet urlsToWatch, + IProxyStorage proxyStorage) : BasePlugin(logger, urlsToWatch) +{ + private readonly IProxyStorage _proxyStorage = proxyStorage; + + public override Func>? OnRequestAsync => (args, cancellationToken) => + { + // Access global data (shared across all requests) + _proxyStorage.GlobalData["MyKey"] = "MyValue"; + + // Access request-specific data using the request ID + var requestData = _proxyStorage.GetRequestData(args.RequestId); + requestData["RequestSpecificKey"] = "RequestSpecificValue"; + + return Task.FromResult(PluginResponse.Continue()); + }; +} +``` + +**Migration patterns:** + +```csharp +// Old API - Global Data +e.GlobalData["MyKey"] = "MyValue"; +var globalValue = e.GlobalData.GetValueOrDefault("MyKey"); + +// New API - Global Data +_proxyStorage.GlobalData["MyKey"] = "MyValue"; +var globalValue = _proxyStorage.GlobalData.GetValueOrDefault("MyKey"); + +// Old API - Session Data +e.SessionData["MyKey"] = "MyValue"; +var sessionValue = e.SessionData.GetValueOrDefault("MyKey"); + +// New API - Request Data +var requestData = _proxyStorage.GetRequestData(args.RequestId); +requestData["MyKey"] = "MyValue"; +var requestValue = requestData.GetValueOrDefault("MyKey"); +``` + +**Important notes about storage:** +- **Global data** persists across all requests and is shared between all plugins +- **Request data** is specific to a single request and is automatically cleaned up when the request completes +- Request data is accessed using the `RequestId` from the `RequestArguments` or `ResponseArguments` +- For reporting plugins that need to store reports, use global data as shown in `BaseReportingPlugin.StoreReport()` + +### 3. New API Benefits +The new API methods provide several advantages: +- **Better Control Flow**: Clear separation between modifying and logging operations +- **Clear Intent**: Method names explicitly indicate their purpose and capabilities +- **Performance**: Logging methods don't block critical paths +- **Separation of Concerns**: Clear distinction between modification and analysis logic + +### 4. Async Considerations +All new API methods expect functions that return Tasks, so you can use async/await within the lambda: +```csharp +public override Func>? OnRequestAsync => async (args, cancellationToken) => +{ + var data = await SomeAsyncOperation(cancellationToken); + // ... process data + return PluginResponse.Continue(); +}; +``` + +### 5. Error Handling +Error handling should be done within the function and appropriate responses returned: +```csharp +public override Func>? OnRequestAsync => async (args, cancellationToken) => +{ + try + { + // Plugin logic + return PluginResponse.Continue(); + } + catch (Exception ex) + { + Logger.LogError(ex, "Error in plugin"); + return PluginResponse.Continue(); // or return an error response + } +}; +``` + +## Migration instructions + +- We have compilation errors, so no need to try to build the project until all plugins are migrated. +- Instead of `string.Equals(args.Request.Method.Method, "OPTIONS", StringComparison.OrdinalIgnoreCase)`, use `args.Request.Method == HttpMethod.Options` for better performance. +- Summarize changes in max two lines + +### General Migration Steps + +1. Migrate plugin according to the new API method (OnRequestAsync, OnRequestLogAsync, OnResponseAsync, or OnResponseLogAsync). +2. Update inventory.md to reflect the new method, leave the old methods in place using strikethrough. +3. Update migration.md with the new migration status + +### For Response Modifying Plugins (OnRequestAsync): +- [ ] Update method signature from `BeforeRequestAsync` to `OnRequestAsync` +- [ ] Change return type from `Task` to `Task` +- [ ] Update input parameter from `ProxyRequestArgs` to `RequestArguments` +- [ ] Replace `e.Session.HttpClient.Request` with `args.Request` +- [ ] Replace `e.HasRequestUrlMatch()` with `ProxyUtils.MatchesUrlToWatch()` +- [ ] Remove `e.ResponseState.HasBeenSet` checks +- [ ] Replace `e.Session.GenericResponse()` with `HttpResponseMessage` creation +- [ ] Replace `e.ResponseState.HasBeenSet = true` with `PluginResponse.Respond()` +- [ ] Replace `return Task.CompletedTask` with `PluginResponse.Continue()` +- [ ] Update logging context from `LoggingContext(e.Session)` to `args.Request` +- [ ] Add `IProxyStorage` to constructor if plugin needs global or request-specific data storage +- [ ] Test the migrated plugin thoroughly + +### For Request Guidance Plugins (OnRequestLogAsync): +- [ ] Update method signature from `BeforeRequestAsync` to `OnRequestLogAsync` +- [ ] Keep return type as `Task` (no PluginResponse needed) +- [ ] Update input parameter from `ProxyRequestArgs` to `RequestArguments` +- [ ] Replace `e.Session.HttpClient.Request` with `args.Request` +- [ ] Replace `e.HasRequestUrlMatch()` with `ProxyUtils.MatchesUrlToWatch()` +- [ ] Remove any response modification logic (not allowed in OnRequestLogAsync) +- [ ] Update logging context from `LoggingContext(e.Session)` to `args.Request` +- [ ] Add `IProxyStorage` to constructor if plugin needs global or request-specific data storage +- [ ] Test the migrated plugin thoroughly + +### For Response Guidance Plugins (OnResponseLogAsync): +- [ ] Update method signature from `AfterResponseAsync` to `OnResponseLogAsync` +- [ ] Keep return type as `Task` (no PluginResponse needed) +- [ ] Update input parameter from `ProxyResponseArgs` to `ResponseArguments` +- [ ] Replace `e.Session.HttpClient.Request` with `args.Request` +- [ ] Replace `e.Session.HttpClient.Response` with `args.Response` +- [ ] Replace `e.HasRequestUrlMatch()` with `ProxyUtils.MatchesUrlToWatch()` +- [ ] Remove any response modification logic (not allowed in OnResponseLogAsync) +- [ ] Update logging context from `LoggingContext(e.Session)` to `args.Request` +- [ ] Add `IProxyStorage` to constructor if plugin needs global or request-specific data storage +- [ ] Test the migrated plugin thoroughly + +### For Response Modifying Plugins (OnResponseAsync): +- [ ] Update method signature from `BeforeResponseAsync` to `OnResponseAsync` +- [ ] Change return type from `Task` to `Task` +- [ ] Update input parameter from `ProxyResponseArgs` to `ResponseArguments` +- [ ] Replace `e.Session.HttpClient.Request` with `args.Request` +- [ ] Replace `e.Session.HttpClient.Response` with `args.Response` +- [ ] Replace `e.HasRequestUrlMatch()` with `ProxyUtils.MatchesUrlToWatch()` +- [ ] Remove `e.ResponseState.HasBeenSet` checks +- [ ] Return `null` to continue or `PluginResponse` to modify +- [ ] Update logging context from `LoggingContext(e.Session)` to `args.Request` +- [ ] Add `IProxyStorage` to constructor if plugin needs global or request-specific data storage +- [ ] Test the migrated plugin thoroughly + +## Plugin Migration Categorization + +Based on the inventory, here's how plugins should be migrated: + +### OnRequestAsync (Response Modifying - 17 plugins): +1. AuthPlugin +2. CrudApiPlugin +3. EntraMockResponsePlugin +4. ~~GenericRandomErrorPlugin~~ (MIGRATED) +5. GraphMockResponsePlugin +6. GraphRandomErrorPlugin (already migrated) +7. LanguageModelFailurePlugin +8. LanguageModelRateLimitingPlugin +9. MockRequestPlugin +10. MockResponsePlugin +11. OpenAIMockResponsePlugin +12. RateLimitingPlugin +13. RetryAfterPlugin + +### OnRequestLogAsync (Request Guidance/Analysis - 20+ plugins): +1. ApiCenterMinimalPermissionsPlugin +2. ApiCenterOnboardingPlugin +3. ApiCenterProductionVersionPlugin +4. CachingGuidancePlugin (MIGRATED) +5. ExecutionSummaryPlugin +6. GraphClientRequestIdGuidancePlugin (MIGRATED) +7. GraphConnectorGuidancePlugin (MIGRATED) +8. GraphMinimalPermissionsGuidancePlugin +9. GraphMinimalPermissionsPlugin +10. GraphSelectGuidancePlugin (MIGRATED) +11. HttpFileGeneratorPlugin +12. MinimalCsomPermissionsPlugin +13. MinimalPermissionsGuidancePlugin +14. MinimalPermissionsPlugin +15. ~~ODSPSearchGuidancePlugin~~ (MIGRATED) +16. OpenAITelemetryPlugin +17. OpenApiSpecGeneratorPlugin +18. TypeSpecGeneratorPlugin +19. UrlDiscoveryPlugin + +### OnResponseLogAsync (Response Guidance/Analysis - plugins analyzing responses): +1. ~~GraphSdkGuidancePlugin~~ (MIGRATED) +2. ~~ODataPagingGuidancePlugin~~ (MIGRATED) + +### Special Cases: +- **RewritePlugin**: Modifies requests before they proceed (may need custom handling) +- **DevToolsPlugin**: Uses multiple methods (BeforeRequestAsync, BeforeResponseAsync, AfterResponseAsync, AfterRequestLogAsync) +- **LatencyPlugin**: Adds delay but doesn't modify responses (could use OnRequestLogAsync) + +The new API methods enable better control flow by allowing the proxy to handle modification and logging operations separately, improving both performance and code clarity. \ No newline at end of file diff --git a/DevProxy.Plugins/packages.lock.json b/DevProxy.Plugins/packages.lock.json index 10eb0fc6..6bf76b16 100644 --- a/DevProxy.Plugins/packages.lock.json +++ b/DevProxy.Plugins/packages.lock.json @@ -105,17 +105,6 @@ "Microsoft.IdentityModel.Tokens": "8.9.0" } }, - "Unobtanium.Web.Proxy": { - "type": "Direct", - "requested": "[0.1.5, )", - "resolved": "0.1.5", - "contentHash": "HiICGm0e44+i4aVHpLn+aphmSC2eQnDvlTttw1rE0hntOZKoLGRy37sydqqbRP1ZokMf3Mt0GEgSWxDwnucKGg==", - "dependencies": { - "BouncyCastle.Cryptography": "2.4.0", - "Microsoft.Extensions.Logging.Abstractions": "8.0.1", - "System.Runtime.CompilerServices.Unsafe": "6.0.0" - } - }, "Azure.Core": { "type": "Transitive", "resolved": "1.44.1", @@ -131,11 +120,6 @@ "System.Threading.Tasks.Extensions": "4.5.4" } }, - "BouncyCastle.Cryptography": { - "type": "Transitive", - "resolved": "2.4.0", - "contentHash": "SwXsAV3sMvAU/Nn31pbjhWurYSjJ+/giI/0n6tCrYoupEK34iIHCuk3STAd9fx8yudM85KkLSVdn951vTng/vQ==" - }, "Markdig": { "type": "Transitive", "resolved": "0.41.3", @@ -572,7 +556,6 @@ "Newtonsoft.Json.Schema": "[4.0.1, )", "Scriban": "[6.2.1, )", "System.CommandLine": "[2.0.0-beta5.25306.1, )", - "Unobtanium.Web.Proxy": "[0.1.5, )", "YamlDotNet": "[16.3.0, )" } } diff --git a/DevProxy.sln b/DevProxy.sln index 9a4c219d..b21d9fa7 100644 --- a/DevProxy.sln +++ b/DevProxy.sln @@ -7,6 +7,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DevProxy", "DevProxy\DevPro EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{1CC1EC7F-4839-43C8-9980-9BCB19609FA9}" ProjectSection(SolutionItems) = preProject + .editorconfig = .editorconfig LICENSE = LICENSE README.md = README.md settings.editorconfig = settings.editorconfig diff --git a/DevProxy/ApiControllers/ProxyController.cs b/DevProxy/ApiControllers/ProxyController.cs index cec82532..cb9db278 100644 --- a/DevProxy/ApiControllers/ProxyController.cs +++ b/DevProxy/ApiControllers/ProxyController.cs @@ -8,13 +8,14 @@ using System.ComponentModel.DataAnnotations; using DevProxy.Proxy; using DevProxy.Abstractions.Proxy; +using Unobtanium.Web.Proxy; namespace DevProxy.ApiControllers; [ApiController] [Route("[controller]")] #pragma warning disable CA1515 // required for the API controller -public sealed class ProxyController(IProxyStateController proxyStateController, IProxyConfiguration proxyConfiguration) : ControllerBase +public sealed class ProxyController(IProxyStateController proxyStateController, IProxyConfiguration proxyConfiguration, ICertificateManager certificateManager) : ControllerBase #pragma warning restore CA1515 { private readonly IProxyStateController _proxyStateController = proxyStateController; @@ -100,7 +101,7 @@ public IActionResult CreateJwtToken([FromBody] JwtOptions jwtOptions) } [HttpGet("rootCertificate")] - public IActionResult GetRootCertificate([FromQuery][Required] string format) + public async Task GetRootCertificateAsync([FromQuery][Required] string format, CancellationToken cancellationToken) { if (string.IsNullOrWhiteSpace(format)) { @@ -114,7 +115,7 @@ public IActionResult GetRootCertificate([FromQuery][Required] string format) return ValidationProblem(ModelState); } - var certificate = ProxyEngine.ProxyServer.CertificateManager.RootCertificate; + var certificate = await certificateManager.GetRootCertificateAsync(false, cancellationToken); if (certificate == null) { var problemDetails = new ProblemDetails diff --git a/DevProxy/Commands/CertCommand.cs b/DevProxy/Commands/CertCommand.cs index 06c05240..d46d0182 100644 --- a/DevProxy/Commands/CertCommand.cs +++ b/DevProxy/Commands/CertCommand.cs @@ -3,34 +3,36 @@ // See the LICENSE file in the project root for more information. using DevProxy.Abstractions.Utils; -using DevProxy.Proxy; using System.CommandLine; using System.CommandLine.Parsing; using System.Diagnostics; -using Titanium.Web.Proxy.Helpers; +using System.Runtime.InteropServices; +using Unobtanium.Web.Proxy; namespace DevProxy.Commands; sealed class CertCommand : Command { private readonly ILogger _logger; + private readonly ICertificateManager _certificateManager; private readonly Option _forceOption = new("--force", "-f") { Description = "Don't prompt for confirmation when removing the certificate" }; - public CertCommand(ILogger logger) : + public CertCommand(ILogger logger, ICertificateManager certificateManager) : base("cert", "Manage the Dev Proxy certificate") { _logger = logger; ConfigureCommand(); + _certificateManager = certificateManager; } private void ConfigureCommand() { var certEnsureCommand = new Command("ensure", "Ensure certificates are setup (creates root if required). Also makes root certificate trusted."); - certEnsureCommand.SetAction(async _ => await EnsureCertAsync()); + certEnsureCommand.SetAction(async (_, cancellationToken) => await EnsureCertAsync(cancellationToken)); var certRemoveCommand = new Command("remove", "Remove the certificate from Root Store"); certRemoveCommand.SetAction(RemoveCert); @@ -43,14 +45,19 @@ private void ConfigureCommand() }.OrderByName()); } - private async Task EnsureCertAsync() + private async Task EnsureCertAsync(CancellationToken cancellationToken) { _logger.LogTrace("EnsureCertAsync() called"); try { _logger.LogInformation("Ensuring certificate exists and is trusted..."); - await ProxyEngine.ProxyServer.CertificateManager.EnsureRootCertificateAsync(); + // TODO: Make the computer trust this certificate + _ = await _certificateManager.GetRootCertificateAsync(false, cancellationToken); + if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + // ... + } _logger.LogInformation("DONE"); } catch (Exception ex) @@ -80,7 +87,7 @@ public void RemoveCert(ParseResult parseResult) _logger.LogInformation("Uninstalling the root certificate..."); RemoveTrustedCertificateOnMac(); - ProxyEngine.ProxyServer.CertificateManager.RemoveTrustedRootCertificate(machineTrusted: false); + // TODO: Implement for Windows/Linux _logger.LogInformation("DONE"); } @@ -118,7 +125,7 @@ private static bool PromptConfirmation(string message, bool acceptByDefault) private static void RemoveTrustedCertificateOnMac() { - if (!RunTime.IsMac) + if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { return; } diff --git a/DevProxy/DevProxy.csproj b/DevProxy/DevProxy.csproj index b7fa312b..890ce970 100644 --- a/DevProxy/DevProxy.csproj +++ b/DevProxy/DevProxy.csproj @@ -1,4 +1,4 @@ - + Exe @@ -15,7 +15,7 @@ README.md false true - true + false AllEnabledByDefault @@ -32,23 +32,24 @@ - + - - + - - + + + + - + diff --git a/DevProxy/Extensions/ILoggingBuilderExtensions.cs b/DevProxy/Extensions/ILoggingBuilderExtensions.cs index 24b40eb4..30047451 100644 --- a/DevProxy/Extensions/ILoggingBuilderExtensions.cs +++ b/DevProxy/Extensions/ILoggingBuilderExtensions.cs @@ -4,6 +4,7 @@ using DevProxy.Commands; using DevProxy.Logging; +using Microsoft.Extensions.DependencyInjection.Extensions; #pragma warning disable IDE0130 namespace Microsoft.Extensions.Logging; @@ -13,7 +14,7 @@ static class ILoggingBuilderExtensions { public static ILoggingBuilder AddRequestLogger(this ILoggingBuilder builder) { - _ = builder.Services.AddSingleton(); + builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton()); return builder; } @@ -27,6 +28,11 @@ public static ILoggingBuilder ConfigureDevProxyLogging( configuration.GetValue("logLevel", LogLevel.Information); _ = builder + .AddOpenTelemetry(config => + { + config.IncludeFormattedMessage = true; + config.IncludeScopes = true; + }) .AddFilter("Microsoft.Hosting.*", LogLevel.Error) .AddFilter("Microsoft.AspNetCore.*", LogLevel.Error) .AddFilter("Microsoft.Extensions.*", LogLevel.Error) @@ -49,7 +55,8 @@ public static ILoggingBuilder ConfigureDevProxyLogging( } ) .AddRequestLogger() - .SetMinimumLevel(configuredLogLevel); + .SetMinimumLevel(configuredLogLevel) + ; return builder; } diff --git a/DevProxy/Extensions/IServiceCollectionExtensions.cs b/DevProxy/Extensions/IServiceCollectionExtensions.cs index 98bed44a..b0ecb954 100644 --- a/DevProxy/Extensions/IServiceCollectionExtensions.cs +++ b/DevProxy/Extensions/IServiceCollectionExtensions.cs @@ -5,9 +5,16 @@ using DevProxy; using DevProxy.Abstractions.Data; using DevProxy.Abstractions.LanguageModel; +using DevProxy.Abstractions.Plugins; using DevProxy.Abstractions.Proxy; using DevProxy.Commands; +using DevProxy.Plugins; using DevProxy.Proxy; +using OpenTelemetry; +using OpenTelemetry.Metrics; +using OpenTelemetry.Trace; +using Unobtanium.Web.Proxy; +using Unobtanium.Web.Proxy.Services; #pragma warning disable IDE0130 namespace Microsoft.Extensions.DependencyInjection; @@ -22,6 +29,7 @@ public static IServiceCollection ConfigureDevProxyServices( { _ = services.AddControllers(); _ = services + .AddOpenTelemetryConfig(configuration) .AddApplicationServices(configuration, options) .AddHostedService() .AddEndpointsApiExplorer() @@ -37,11 +45,28 @@ static IServiceCollection AddApplicationServices( DevProxyConfigOptions options) { _ = services + .AddProxyHttpClientFactory() + .AddProxyEvents(new Unobtanium.Web.Proxy.Events.ProxyServerEvents()) + .Configure(options => + { + options.Port = configuration.GetValue("Port", ProxyServerDefaults.DEFAULT_PORT); + //options.TrustCertificateOnStart = true; // Automatically trust the certificate on start, NOT IMPLEMENTED YET!! + options.TrustCertificateOnStartAsUser = true; // Automatically trust the certificate on start as user, NOT IMPLEMENTED YET!! + }) + .Configure (certOptions => + { + certOptions.RootCertificateName = "Dev Proxy CA"; + certOptions.CachePath = configuration.GetValue("DEV_PROXY_CERT_PATH"); + }) + .AddProxyServices() // This adds the background services for the proxy and adds the default ICertificateManager .AddSingleton((IConfigurationRoot)configuration) .AddSingleton() .AddSingleton() .AddSingleton() - .AddSingleton(sp => ProxyEngine.Certificate!) + .AddSingleton(sp => new ProxyStorage(sp.GetRequiredService())) + // TODO: Removed the injected certificate + //.AddSingleton(sp => ProxyEngine.Certificate!) // Why is this injected? + //.AddSingleton(sp => sp.GetRequiredService().CertificateManager.RootCertificate!) .AddSingleton(sp => LanguageModelClientFactory.Create(sp, configuration)) .AddSingleton() .AddSingleton() @@ -53,4 +78,50 @@ static IServiceCollection AddApplicationServices( return services; } + + static IServiceCollection AddProxyHttpClientFactory(this IServiceCollection services) + { + _ = services.AddHttpClient(EfficientProxyHttpClientFactory.HTTP_CLIENT_NAME, client => + { + // Configure the HttpClient as needed, e.g., set base address, default headers, etc. + //client.BaseAddress = new Uri("https://graph.microsoft.com/"); + }) + .ConfigurePrimaryHttpMessageHandler(() => + { + // Configure HttpClientHandler to explicitly bypass system proxy settings + return new HttpClientHandler() + { + UseProxy = false, // Explicitly disable proxy usage + Proxy = null // Ensure no proxy is set + }; + }); + _ = services.AddTransient(); + return services; + } + + static IServiceCollection AddOpenTelemetryConfig(this IServiceCollection services, IConfigurationRoot configuration) + { + var openTelemetryBuilder = services + .AddOpenTelemetry() + .WithMetrics(metrics => + { + _ = metrics + .AddHttpClientInstrumentation() + .AddRuntimeInstrumentation(); + + }).WithTracing(tracing => + { + _ = tracing + .AddSource(ProxyServerDefaults.ACTIVITY_SOURCE_NAME) + .AddSource(ProxyEngine.ACTIVITY_SOURCE_NAME) + .AddHttpClientInstrumentation() + ; + }); + var endpoint = configuration.GetValue("OTEL_EXPORTER_OTLP_ENDPOINT"); + if (!string.IsNullOrEmpty(endpoint)) + { + _ = openTelemetryBuilder.UseOtlpExporter(); + } + return services; + } } \ No newline at end of file diff --git a/DevProxy/Logging/ILoggerExtensions.cs b/DevProxy/Logging/ILoggerExtensions.cs index 53f1870c..882c9d32 100644 --- a/DevProxy/Logging/ILoggerExtensions.cs +++ b/DevProxy/Logging/ILoggerExtensions.cs @@ -15,4 +15,13 @@ static class ILoggerExtensions { nameof(url), url }, { nameof(requestId), requestId } }); + + public static IDisposable? BeginRequestScope(this ILogger logger, HttpMethod method, Uri url, string requestId) => + logger.BeginScope(new Dictionary + { + { nameof(method), method }, + { nameof(url), url }, + { nameof(requestId), requestId } + }); + } \ No newline at end of file diff --git a/DevProxy/Logging/ProxyConsoleFormatter.cs b/DevProxy/Logging/ProxyConsoleFormatter.cs index f1e3b3ad..498e8633 100644 --- a/DevProxy/Logging/ProxyConsoleFormatter.cs +++ b/DevProxy/Logging/ProxyConsoleFormatter.cs @@ -69,7 +69,7 @@ sealed class ProxyConsoleFormatter : ConsoleFormatter [MessageType.Timestamp] = (Console.BackgroundColor, ConsoleColor.Gray) }; - private readonly ConcurrentDictionary> _messages = []; + private readonly ConcurrentDictionary> _messages = []; private readonly ProxyConsoleFormatterOptions _options; private readonly HashSet _filteredMessageTypes; @@ -118,16 +118,20 @@ private void LogRequest(RequestLog requestLog, string category, IExternalScopePr if (messageType == MessageType.FinishedProcessingRequest) { - FlushLogsForRequest(requestId.Value, textWriter); + FlushLogsForRequest(requestId, textWriter); } else { - BufferRequestLog(requestLog, category, requestId.Value); + BufferRequestLog(requestLog, category, requestId); } } - private void FlushLogsForRequest(int requestId, TextWriter textWriter) + private void FlushLogsForRequest(string? requestId, TextWriter textWriter) { + if (string.IsNullOrEmpty(requestId) || textWriter is null) + { + return; + } if (!_messages.TryGetValue(requestId, out var messages)) { return; @@ -153,8 +157,12 @@ private void FlushLogsForRequest(int requestId, TextWriter textWriter) _ = _messages.TryRemove(requestId, out _); } - private void BufferRequestLog(RequestLog requestLog, string category, int requestId) + private void BufferRequestLog(RequestLog requestLog, string category, string? requestId) { + if (string.IsNullOrEmpty(requestId)) + { + return; + } requestLog.PluginName = category == DefaultCategoryName ? null : category; var messages = _messages.GetOrAdd(requestId, _ => []); messages.Add(requestLog); @@ -170,7 +178,7 @@ private void LogRegularLogMessage(in LogEntry logEntry, IExterna else { var message = LogEntry.FromLogEntry(logEntry); - var messages = _messages.GetOrAdd(requestId.Value, _ => []); + var messages = _messages.GetOrAdd(requestId, _ => []); messages.Add(message); } } @@ -289,15 +297,15 @@ private static string GetMessageTypeString(MessageType messageType) => private static (ConsoleColor bg, ConsoleColor fg) GetMessageTypeColor(MessageType messageType) => _messageTypeColors.TryGetValue(messageType, out var color) ? color : (Console.BackgroundColor, Console.ForegroundColor); - private static int? GetRequestIdScope(IExternalScopeProvider? scopeProvider) + private static string? GetRequestIdScope(IExternalScopeProvider? scopeProvider) { - int? requestId = null; + string? requestId = null; scopeProvider?.ForEachScope((scope, _) => { if (scope is Dictionary dictionary && dictionary.TryGetValue(nameof(requestId), out var req)) { - requestId = (int)req; + requestId = $"{req}"; } }, ""); return requestId; diff --git a/DevProxy/Logging/RequestLogger.cs b/DevProxy/Logging/RequestLogger.cs index fa4c1337..9577f7b6 100644 --- a/DevProxy/Logging/RequestLogger.cs +++ b/DevProxy/Logging/RequestLogger.cs @@ -35,12 +35,12 @@ public void Log(LogLevel logLevel, EventId eventId, TState state, Except // Lazily resolve plugins to avoid circular dependency var plugins = _serviceProvider.GetRequiredService>(); - foreach (var plugin in plugins.Where(p => p.Enabled)) + foreach (var plugin in plugins.Where(p => p.Enabled && p.HandleRequestLogAsync is not null)) { // we don't have the app's cancellation token in the current // implementation, but should it change in the future, // we won't have to break the interface - joinableTaskFactory.Run(async () => await plugin.AfterRequestLogAsync(requestLogArgs, CancellationToken.None)); + joinableTaskFactory.Run(async () => await plugin.HandleRequestLogAsync!(requestLogArgs, CancellationToken.None)); } } } diff --git a/DevProxy/Plugins/ProxyStorage.cs b/DevProxy/Plugins/ProxyStorage.cs new file mode 100644 index 00000000..16364873 --- /dev/null +++ b/DevProxy/Plugins/ProxyStorage.cs @@ -0,0 +1,24 @@ +using DevProxy.Abstractions.Plugins; +using DevProxy.Proxy; +using System.Collections.Concurrent; + +namespace DevProxy.Plugins; + +/// +/// Default implementation of . +/// +internal class ProxyStorage : IProxyStorage +{ + internal ProxyStorage(IProxyState proxyState) + { + GlobalData = proxyState.GlobalData ?? throw new ArgumentException("GlobalData cannot be null.", nameof(proxyState)); + } + public Dictionary GlobalData { get; private set; } + + //Dictionary IProxyStorage.GlobalData => throw new NotImplementedException(); + + public Dictionary GetRequestData(RequestId id) => _requestData.TryGetValue(id, out var data) ? data : []; + public void RemoveRequestData(RequestId id) => _requestData.Remove(id, out _); + + private readonly ConcurrentDictionary> _requestData = []; +} diff --git a/DevProxy/Program.cs b/DevProxy/Program.cs index 8a16ae1e..c254ea08 100644 --- a/DevProxy/Program.cs +++ b/DevProxy/Program.cs @@ -5,6 +5,8 @@ using DevProxy; using DevProxy.Commands; using System.Net; +using Unobtanium.Web.Proxy; +using Unobtanium.Web.Proxy.Services; static WebApplication BuildApplication(string[] args, DevProxyConfigOptions options) { @@ -13,6 +15,19 @@ static WebApplication BuildApplication(string[] args, DevProxyConfigOptions opti _ = builder.Configuration.ConfigureDevProxyConfig(options); _ = builder.Logging.ConfigureDevProxyLogging(builder.Configuration, options); _ = builder.Services.ConfigureDevProxyServices(builder.Configuration, options); + _ = builder.Services + .Configure(options => + { + options.Port = ProxyServerDefaults.DEFAULT_PORT; // Set the port for the proxy server + options.HttpsPort = ProxyServerDefaults.DEFAULT_HTTPS_PORT; + }) + .Configure(config => + { + config.CachePath = DevProxy.Abstractions.Utils.ProxyUtils.ReplacePathTokens( + builder.Configuration.GetValue("certificateCachePath", "certs")); + }) + .AddProxyEvents(new Unobtanium.Web.Proxy.Events.ProxyServerEvents()) + .AddProxyServices(); var defaultIpAddress = "127.0.0.1"; var ipAddress = options.IPAddress ?? diff --git a/DevProxy/Properties/launchSettings.json b/DevProxy/Properties/launchSettings.json index 30111fdf..2937063a 100755 --- a/DevProxy/Properties/launchSettings.json +++ b/DevProxy/Properties/launchSettings.json @@ -2,7 +2,9 @@ "profiles": { "No args": { "commandName": "Project", - "commandLineArgs": "", + "environmentVariables": { + "AsSystemProxy": "false" + }, "hotReloadEnabled": true }, "No chaos with mock responses": { @@ -32,7 +34,12 @@ }, "Default": { "commandName": "Project", - "hotReloadEnabled": true + "hotReloadEnabled": true, + "environmentVariables": { + "OTEL_EXPORTER_OTLP_ENDPOINT": "http://localhost:4317", + "OTEL_EXPORTER_OTLP_PROTOCOL": "grpc", + "AsSystemProxy": "false" + } }, "Missing arg": { "commandName": "Project", diff --git a/DevProxy/Proxy/CertificateDiskCache.cs b/DevProxy/Proxy/CertificateDiskCache.cs deleted file mode 100644 index e7442a63..00000000 --- a/DevProxy/Proxy/CertificateDiskCache.cs +++ /dev/null @@ -1,142 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Security.Cryptography.X509Certificates; -using Titanium.Web.Proxy.Certificates.Cache; -using Titanium.Web.Proxy.Helpers; - -namespace DevProxy.Proxy; - -// based on https://github.com/justcoding121/titanium-web-proxy/blob/9e71608d204e5b67085656dd6b355813929801e4/src/Titanium.Web.Proxy/Certificates/Cache/DefaultCertificateDiskCache.cs -internal sealed class CertificateDiskCache : ICertificateCache -{ - private const string DefaultCertificateDirectoryName = "crts"; - private const string DefaultCertificateFileExtension = ".pfx"; - private const string DefaultRootCertificateFileName = "rootCert" + DefaultCertificateFileExtension; - private const string ProxyConfigurationFolderName = "dev-proxy"; - - private string? rootCertificatePath; - - public Task LoadRootCertificateAsync(string pathOrName, string password, X509KeyStorageFlags storageFlags, CancellationToken cancellationToken) - { - var path = GetRootCertificatePath(pathOrName, false); - return Task.FromResult(LoadCertificate(path, password, storageFlags)); - } - - public async Task SaveRootCertificateAsync(string pathOrName, string password, X509Certificate2 certificate, CancellationToken cancellationToken) - { - var path = GetRootCertificatePath(pathOrName, true); - var exported = certificate.Export(X509ContentType.Pkcs12, password); - await File.WriteAllBytesAsync(path, exported, cancellationToken); - } - - public Task LoadCertificateAsync(string subjectName, X509KeyStorageFlags storageFlags, CancellationToken cancellationToken) - { - var filePath = Path.Combine(GetCertificatePath(false), subjectName + DefaultCertificateFileExtension); - return Task.FromResult(LoadCertificate(filePath, string.Empty, storageFlags)); - } - - public async Task SaveCertificateAsync(string subjectName, X509Certificate2 certificate, CancellationToken cancellationToken) - { - var filePath = Path.Combine(GetCertificatePath(true), subjectName + DefaultCertificateFileExtension); - var exported = certificate.Export(X509ContentType.Pkcs12); - await File.WriteAllBytesAsync(filePath, exported, cancellationToken); - } - - public void Clear() - { - try - { - var path = GetCertificatePath(false); - if (Directory.Exists(path)) - { - Directory.Delete(path, true); - } - } - catch (Exception) - { - // do nothing - } - } - - private string GetRootCertificatePath(string pathOrName, bool create) - { - if (Path.IsPathRooted(pathOrName)) - { - return pathOrName; - } - - return Path.Combine(GetRootCertificateDirectory(create), - string.IsNullOrEmpty(pathOrName) ? DefaultRootCertificateFileName : pathOrName); - } - - private string GetCertificatePath(bool create) - { - var path = GetRootCertificateDirectory(create); - - var certPath = Path.Combine(path, DefaultCertificateDirectoryName); - if (create && !Directory.Exists(certPath)) - { - _ = Directory.CreateDirectory(certPath); - } - - return certPath; - } - - private string GetRootCertificateDirectory(bool create) - { - if (rootCertificatePath == null) - { - if (RunTime.IsUwpOnWindows) - { - rootCertificatePath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), ProxyConfigurationFolderName); - } - else if (RunTime.IsLinux) - { - rootCertificatePath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), ProxyConfigurationFolderName); - } - else if (RunTime.IsMac) - { - rootCertificatePath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), ProxyConfigurationFolderName); - } - else - { - var assemblyLocation = AppContext.BaseDirectory; - - var path = Path.GetDirectoryName(assemblyLocation); - - rootCertificatePath = path ?? throw new InvalidOperationException("Unable to resolve root certificate directory path."); - } - } - - if (create && !Directory.Exists(rootCertificatePath)) - { - _ = Directory.CreateDirectory(rootCertificatePath); - } - - return rootCertificatePath; - } - - private static X509Certificate2? LoadCertificate(string path, string password, X509KeyStorageFlags storageFlags) - { - byte[] exported; - - if (!File.Exists(path)) - { - return null; - } - - try - { - exported = File.ReadAllBytes(path); - } - catch (IOException) - { - // file or directory not found - return null; - } - - return X509CertificateLoader.LoadPkcs12(exported, password, storageFlags); - } -} \ No newline at end of file diff --git a/DevProxy/Proxy/EfficientProxyHttpClientFactory.cs b/DevProxy/Proxy/EfficientProxyHttpClientFactory.cs new file mode 100644 index 00000000..37d4ca56 --- /dev/null +++ b/DevProxy/Proxy/EfficientProxyHttpClientFactory.cs @@ -0,0 +1,15 @@ +using Unobtanium.Web.Proxy; + +namespace DevProxy.Proxy; + +/// +/// for efficient re-use of available ports +/// +/// Is added to Dependency Injection +internal sealed class EfficientProxyHttpClientFactory(IHttpClientFactory httpClientFactory) : IProxyHttpClientFactory +{ + internal const string HTTP_CLIENT_NAME = "DevProxy.Proxy.EfficientHttpClient"; + private readonly IHttpClientFactory _httpClientFactory = httpClientFactory; + + public HttpClient CreateHttpClient(string host) => _httpClientFactory.CreateClient(HTTP_CLIENT_NAME); +} diff --git a/DevProxy/Proxy/ProxyEngine.cs b/DevProxy/Proxy/ProxyEngine.cs index 04c1498c..cb98aacc 100755 --- a/DevProxy/Proxy/ProxyEngine.cs +++ b/DevProxy/Proxy/ProxyEngine.cs @@ -5,17 +5,13 @@ using DevProxy.Abstractions.Plugins; using DevProxy.Abstractions.Proxy; using DevProxy.Abstractions.Utils; -using Microsoft.VisualStudio.Threading; -using System.Collections.Concurrent; +using System; using System.Diagnostics; -using System.Net; -using System.Security.Cryptography.X509Certificates; +using System.Net.Http.Headers; +using System.Runtime.InteropServices; using System.Text.RegularExpressions; -using Titanium.Web.Proxy; -using Titanium.Web.Proxy.EventArguments; -using Titanium.Web.Proxy.Helpers; -using Titanium.Web.Proxy.Http; -using Titanium.Web.Proxy.Models; +using Unobtanium.Web.Proxy; +using Unobtanium.Web.Proxy.Events; namespace DevProxy.Proxy; @@ -30,14 +26,19 @@ sealed class ProxyEngine( IProxyConfiguration proxyConfiguration, ISet urlsToWatch, IProxyStateController proxyController, - ILogger logger) : BackgroundService, IDisposable + ILogger logger, + ProxyServerEvents proxyEvents, + ICertificateManager certificateManager, + IProxyStorage proxyStorage) : BackgroundService, IDisposable { + internal const string ACTIVITY_SOURCE_NAME = "DevProxy.Proxy.ProxyEngine"; + public static readonly ActivitySource ActivitySource = new(ACTIVITY_SOURCE_NAME); private readonly IEnumerable _plugins = plugins; private readonly ILogger _logger = logger; private readonly IProxyConfiguration _config = proxyConfiguration; - internal static ProxyServer ProxyServer { get; private set; } - private ExplicitProxyEndPoint? _explicitEndPoint; + //internal static ProxyServer ProxyServer { get; private set; } + //private ExplicitProxyEndPoint? _explicitEndPoint; // lists of URLs to watch, used for intercepting requests private readonly ISet _urlsToWatch = urlsToWatch; // lists of hosts to watch extracted from urlsToWatch, @@ -46,34 +47,34 @@ sealed class ProxyEngine( private readonly IProxyStateController _proxyController = proxyController; // Dictionary for plugins to store data between requests // the key is HashObject of the SessionEventArgs object - private readonly ConcurrentDictionary> _pluginData = []; + //private readonly ConcurrentDictionary> _pluginData = []; private InactivityTimer? _inactivityTimer; private CancellationToken? _cancellationToken; - public static X509Certificate2? Certificate => ProxyServer?.CertificateManager.RootCertificate; + //public static X509Certificate2? Certificate => proxyServer?.CertificateManager.RootCertificate; - private ExceptionHandler ExceptionHandler => ex => _logger.LogError(ex, "An error occurred in a plugin"); + //private ExceptionHandler ExceptionHandler => ex => _logger.LogError(ex, "An error occurred in a plugin"); - static ProxyEngine() - { - ProxyServer = new(); - ProxyServer.CertificateManager.PfxFilePath = Environment.GetEnvironmentVariable("DEV_PROXY_CERT_PATH") ?? string.Empty; - ProxyServer.CertificateManager.RootCertificateName = "Dev Proxy CA"; - ProxyServer.CertificateManager.CertificateStorage = new CertificateDiskCache(); - // we need to change this to a value lower than 397 - // to avoid the ERR_CERT_VALIDITY_TOO_LONG error in Edge - ProxyServer.CertificateManager.CertificateValidDays = 365; - - using var joinableTaskContext = new JoinableTaskContext(); - var joinableTaskFactory = new JoinableTaskFactory(joinableTaskContext); - _ = joinableTaskFactory.Run(async () => await ProxyServer.CertificateManager.LoadOrCreateRootCertificateAsync()); - } + //static ProxyEngine() + //{ + // ProxyServer = new(); + // ProxyServer.CertificateManager.PfxFilePath = Environment.GetEnvironmentVariable("DEV_PROXY_CERT_PATH") ?? string.Empty; + // ProxyServer.CertificateManager.RootCertificateName = "Dev Proxy CA"; + // ProxyServer.CertificateManager.CertificateStorage = new CertificateDiskCache(); + // // we need to change this to a value lower than 397 + // // to avoid the ERR_CERT_VALIDITY_TOO_LONG error in Edge + // ProxyServer.CertificateManager.CertificateValidDays = 365; + + // using var joinableTaskContext = new JoinableTaskContext(); + // var joinableTaskFactory = new JoinableTaskFactory(joinableTaskContext); + // _ = joinableTaskFactory.Run(async () => await ProxyServer.CertificateManager.LoadOrCreateRootCertificateAsync()); + //} protected override async Task ExecuteAsync(CancellationToken stoppingToken) { _cancellationToken = stoppingToken; - Debug.Assert(ProxyServer is not null, "Proxy server is not initialized"); + Debug.Assert(proxyEvents is not null, "Proxy server is not initialized"); if (!_urlsToWatch.Any()) { @@ -83,46 +84,66 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) LoadHostNamesFromUrls(); - ProxyServer.BeforeRequest += OnRequestAsync; - ProxyServer.BeforeResponse += OnBeforeResponseAsync; - ProxyServer.AfterResponse += OnAfterResponseAsync; - ProxyServer.ServerCertificateValidationCallback += OnCertificateValidationAsync; - ProxyServer.ClientCertificateSelectionCallback += OnCertificateSelectionAsync; + // TODO: Handle replacement of BeforeRequest + //proxyServer.BeforeRequest += OnRequestAsync; + proxyEvents.OnRequest += OnRequestAsync; + + // TODO: Handle removal of BeforeResponse + //proxyServer.BeforeResponse += OnBeforeResponseAsync; + + // TODO: Handle replacement of AfterResponse + //proxyServer.AfterResponse += OnAfterResponseAsync; + proxyEvents.OnResponse += OnResponseAsync; + + //proxyServer.ServerCertificateValidationCallback += OnCertificateValidationAsync; + //proxyServer.ClientCertificateSelectionCallback += OnCertificateSelectionAsync; - var ipAddress = string.IsNullOrEmpty(_config.IPAddress) ? IPAddress.Any : IPAddress.Parse(_config.IPAddress); - _explicitEndPoint = new(ipAddress, _config.Port, true); + // Endpoint is configured in IServiceCollectionExtensions.AddProxyConfiguration + //var ipAddress = string.IsNullOrEmpty(_config.IPAddress) ? IPAddress.Any : IPAddress.Parse(_config.IPAddress); + //_explicitEndPoint = new(ipAddress, _config.Port, true); + + // TODO: Implement process validation // Fired when a CONNECT request is received - _explicitEndPoint.BeforeTunnelConnectRequest += OnBeforeTunnelConnectRequestAsync; + //_explicitEndPoint.BeforeTunnelConnectRequest += OnBeforeTunnelConnectRequestAsync; + // This is superceeded by: + proxyEvents.ShouldDecryptNewConnection = (host, client, cts) => Task.FromResult(IsProxiedHost(host));// || IsProxiedProcess(...)); if (_config.InstallCert) { - await ProxyServer.CertificateManager.EnsureRootCertificateAsync(stoppingToken); + _ = await certificateManager.GetRootCertificateAsync(false, stoppingToken); + // TODO: Execute code to trust certificate } else { - _explicitEndPoint.GenericCertificate = await ProxyServer - .CertificateManager - .LoadRootCertificateAsync(stoppingToken); + // TODO: Remove this code, happens automatically + //_explicitEndPoint.GenericCertificate = await proxyServer + // .CertificateManager + // .LoadRootCertificateAsync(stoppingToken); } - ProxyServer.AddEndPoint(_explicitEndPoint); - await ProxyServer.StartAsync(cancellationToken: stoppingToken); + //proxyServer.AddEndPoint(_explicitEndPoint); + //await proxyServer.StartAsync(cancellationToken: stoppingToken); // run first-run setup on macOS FirstRunSetup(); - foreach (var endPoint in ProxyServer.ProxyEndPoints) - { - _logger.LogInformation("Dev Proxy listening on {IPAddress}:{Port}...", endPoint.IpAddress, endPoint.Port); - } + //ExplicitProxyEndPoint? explicitProxyEndPoint = null; + + //foreach (var endPoint in proxyServer.ProxyEndPoints) + //{ + // _logger.LogInformation("Dev Proxy listening on {IPAddress}:{Port}...", endPoint.IpAddress, endPoint.Port); + // if (explicitProxyEndPoint is null && endPoint is ExplicitProxyEndPoint explicitProxyEnd) + // { + // explicitProxyEndPoint = explicitProxyEnd; + // } + //} if (_config.AsSystemProxy) { - if (RunTime.IsWindows) + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - ProxyServer.SetAsSystemHttpProxy(_explicitEndPoint); - ProxyServer.SetAsSystemHttpsProxy(_explicitEndPoint); + //TODO: Implement Windows system proxy toggle } - else if (RunTime.IsMac) + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { ToggleSystemProxy(ToggleSystemProxyAction.On, _config.IPAddress, _config.Port); } @@ -162,7 +183,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) try { - while (!stoppingToken.IsCancellationRequested && ProxyServer.ProxyRunning) + while (!stoppingToken.IsCancellationRequested) { while (!Console.KeyAvailable) { @@ -180,7 +201,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) private void FirstRunSetup() { - if (!RunTime.IsMac || + if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX) || _config.NoFirstRun || !HasRunFlag.CreateIfMissing() || !_config.InstallCert) @@ -294,28 +315,15 @@ private void StopProxy() // Unsubscribe & Quit try { - if (_explicitEndPoint != null) - { - _explicitEndPoint.BeforeTunnelConnectRequest -= OnBeforeTunnelConnectRequestAsync; - } - - if (ProxyServer is not null) - { - ProxyServer.BeforeRequest -= OnRequestAsync; - ProxyServer.BeforeResponse -= OnBeforeResponseAsync; - ProxyServer.AfterResponse -= OnAfterResponseAsync; - ProxyServer.ServerCertificateValidationCallback -= OnCertificateValidationAsync; - ProxyServer.ClientCertificateSelectionCallback -= OnCertificateSelectionAsync; - - if (ProxyServer.ProxyRunning) - { - ProxyServer.Stop(); - } - } + //if (_explicitEndPoint != null) + //{ + // _explicitEndPoint.BeforeTunnelConnectRequest -= OnBeforeTunnelConnectRequestAsync; + //} + // proxyServer is stopped automatically when the service is stopped _inactivityTimer?.Stop(); - if (RunTime.IsMac && _config.AsSystemProxy) + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX) && _config.AsSystemProxy) { ToggleSystemProxy(ToggleSystemProxyAction.Off); } @@ -334,18 +342,7 @@ public override async Task StopAsync(CancellationToken cancellationToken) await base.StopAsync(cancellationToken); } - async Task OnBeforeTunnelConnectRequestAsync(object sender, TunnelConnectSessionEventArgs e) - { - // Ensures that only the targeted Https domains are proxyied - if (!IsProxiedHost(e.HttpClient.Request.RequestUri.Host) || - !IsProxiedProcess(e)) - { - e.DecryptSsl = false; - } - await Task.CompletedTask; - } - - private bool IsProxiedProcess(TunnelConnectSessionEventArgs e) + private bool IsProxiedProcess(ClientDetails clientDetails) { // If no process names or IDs are specified, we proxy all processes if (!_config.WatchPids.Any() && @@ -354,14 +351,13 @@ private bool IsProxiedProcess(TunnelConnectSessionEventArgs e) return true; } - var processId = GetProcessId(e); + var processId = GetProcessId(clientDetails); if (processId == -1) { return false; } - if (_config.WatchPids.Any() && - _config.WatchPids.Contains(processId)) + if (_config.WatchPids.Contains(processId)) { return true; } @@ -378,69 +374,123 @@ private bool IsProxiedProcess(TunnelConnectSessionEventArgs e) return false; } - async Task OnRequestAsync(object sender, SessionEventArgs e) + async Task OnRequestAsync(object _, RequestEventArguments requestEventArguments, CancellationToken cancellationToken) { _inactivityTimer?.Reset(); - if (IsProxiedHost(e.HttpClient.Request.RequestUri.Host) && - IsIncludedByHeaders(e.HttpClient.Request.Headers)) + if (IsProxiedHost(requestEventArguments.Request.RequestUri!.Host) && + IsIncludedByHeaders(requestEventArguments.Request.Headers)) { - if (!_pluginData.TryAdd(e.GetHashCode(), [])) - { - throw new InvalidOperationException($"Unable to initialize the plugin data storage for hash key {e.GetHashCode()}"); - } - var responseState = new ResponseState(); - var proxyRequestArgs = new ProxyRequestArgs(e, responseState) - { - SessionData = _pluginData[e.GetHashCode()], - GlobalData = _proxyController.ProxyState.GlobalData - }; - if (!proxyRequestArgs.HasRequestUrlMatch(_urlsToWatch)) - { - return; - } + //if (!_pluginData.TryAdd(requestEventArguments.RequestId, [])) + //{ + // throw new InvalidOperationException($"Unable to initialize the plugin data storage for hash key {requestEventArguments.RequestId}"); + //} + + using var scope = _logger.BeginRequestScope(requestEventArguments.Request.Method, requestEventArguments.Request.RequestUri, requestEventArguments.RequestId); + _logger.LogRequest($"{requestEventArguments.Request.Method} {requestEventArguments.Request.RequestUri}", MessageType.InterceptedRequest, requestEventArguments.Request, requestEventArguments.RequestId); + _logger.LogRequest($"{DateTimeOffset.UtcNow}", MessageType.Timestamp, requestEventArguments.Request, requestEventArguments.RequestId); - // we need to keep the request body for further processing - // by plugins - e.HttpClient.Request.KeepBody = true; - if (e.HttpClient.Request.HasBody) + if (!ProxyUtils.MatchesUrlToWatch(_urlsToWatch, requestEventArguments.Request.RequestUri.AbsoluteUri)) { - _ = await e.GetRequestBodyAsString(); + return RequestEventResponse.ContinueResponse(); } - using var scope = _logger.BeginScope(e.HttpClient.Request.Method ?? "", e.HttpClient.Request.Url, e.GetHashCode()); + //if (!_pluginData.TryAdd(requestEventArguments.RequestId, [])) + //{ + // // Throwing here will break the request.... + // throw new InvalidOperationException($"Unable to initialize the plugin data storage for hash key {requestEventArguments.RequestId}"); + //} - e.UserData = e.HttpClient.Request; + //var loggingContext = new LoggingContext(e); - var loggingContext = new LoggingContext(e); - _logger.LogRequest($"{e.HttpClient.Request.Method} {e.HttpClient.Request.Url}", MessageType.InterceptedRequest, loggingContext); - _logger.LogRequest($"{DateTimeOffset.UtcNow}", MessageType.Timestamp, loggingContext); - - await HandleRequestAsync(e, proxyRequestArgs); + return await HandleRequestAsync(requestEventArguments, cancellationToken); } + + return RequestEventResponse.ContinueResponse(); } - private async Task HandleRequestAsync(SessionEventArgs e, ProxyRequestArgs proxyRequestArgs) + private async Task HandleRequestAsync(RequestEventArguments arguments, CancellationToken cancellationToken) { - foreach (var plugin in _plugins.Where(p => p.Enabled)) - { - _cancellationToken?.ThrowIfCancellationRequested(); + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cancellationToken ?? CancellationToken.None); + using var scope = _logger.BeginRequestScope(arguments.Request.Method, arguments.Request.RequestUri!, arguments.RequestId); + _logger.LogRequest($"{arguments.Request.Method} {arguments.Request.RequestUri}", MessageType.InterceptedRequest, arguments.Request, arguments.RequestId); + + // Plugins that don't modify the request but provide guidance on them + // can be called in parallel, because they don't affect each other. + var guidancePlugins = _plugins.Where(p => p.Enabled && p.ProvideRequestGuidanceAsync is not null); + if (guidancePlugins.Any()) + { + var logArguments = new RequestArguments(arguments.Request, arguments.RequestId); + // Call OnRequestLogAsync for all plugins at the same time and wait for all of them to complete + var logTasks = guidancePlugins + .Select(plugin => plugin.ProvideRequestGuidanceAsync!(logArguments, cts.Token)) + .ToArray(); + try + { + await Task.WhenAll(logTasks); + } + catch (Exception ex) + { + //_logger.LogRequest(ex.Message, MessageType.Failed, arguments.Request, arguments.RequestId); + _logger.LogError(ex, "An error occurred in a plugin while logging request {RequestMethod} {RequestUrl}", + arguments.Request.Method, arguments.Request.RequestUri); + } + } + HttpResponseMessage? response = null; + HttpRequestMessage? request = null; + foreach (var plugin in _plugins + .Where(p => + p.Enabled + && p.OnRequestAsync is not null)) // Only plugins that have OnRequestAsync defined, maybe pre-select matches based on url? + { + cts.Token.ThrowIfCancellationRequested(); try { - await plugin.BeforeRequestAsync(proxyRequestArgs, _cancellationToken ?? CancellationToken.None); + var result = await plugin.OnRequestAsync!(new RequestArguments(arguments.Request, arguments.RequestId), cts.Token); + if (result is not null) + { + if (result.Request is not null) + { + request = result.Request; + // TODO: Decide what to do in this case, continue processing or return the request? + } + else if (result.Response is not null) + { + response = result.Response; + // Plugins no longer have to check if the response is already been set. + // If a plugin sets a response, it is expected to be the final response. + break; + } + } } catch (Exception ex) { - ExceptionHandler(ex); + _logger.LogError(ex, "An error occurred in plugin {PluginName} while processing request {RequestMethod} {RequestUrl}", + plugin.Name, arguments.Request.Method, arguments.Request.RequestUri); + } } // We only need to set the proxy header if the proxy has not set a response and the request is going to be sent to the target. - if (!proxyRequestArgs.ResponseState.HasBeenSet) + if (response is not null) { - _logger?.LogRequest("Passed through", MessageType.PassedThrough, new(e)); - AddProxyHeader(e.HttpClient.Request); + proxyStorage.RemoveRequestData(arguments.RequestId); + _logger.LogRequest($"{(int)response.StatusCode} {response.StatusCode}", MessageType.Mocked, arguments.Request, arguments.RequestId); + _logger.LogRequest("Done", MessageType.FinishedProcessingRequest, arguments.Request, arguments.RequestId, response); + return RequestEventResponse.EarlyResponse(response); } + else if (request is not null) + { + // If the request is modified, we need to add the Via header + AddProxyHeader(request); + _logger.LogRequest("Request modified by plugin", MessageType.InterceptedRequest, arguments.Request, arguments.RequestId); + //_logger.LogRequest($"{arguments.Request.Method} {arguments.Request.RequestUri}", MessageType.Processed, arguments.Request, arguments.RequestId); + // We can return the request to be sent to the target + return RequestEventResponse.ModifyRequest(request); + } + // If no plugins modified the request, we add the Via header to the original request + AddProxyHeader(arguments.Request); + return RequestEventResponse.ModifyRequest(arguments.Request); } private bool IsProxiedHost(string hostName) @@ -449,7 +499,7 @@ private bool IsProxiedHost(string hostName) return urlMatch is not null && !urlMatch.Exclude; } - private bool IsIncludedByHeaders(HeaderCollection requestHeaders) + private bool IsIncludedByHeaders(HttpRequestHeaders requestHeaders) { if (_config.FilterByHeaders is null) { @@ -463,7 +513,7 @@ private bool IsIncludedByHeaders(HeaderCollection requestHeaders) string.IsNullOrEmpty(header.Value) ? "(any)" : header.Value ); - if (requestHeaders.HeaderExists(header.Name)) + if (requestHeaders.Contains(header.Name)) { if (string.IsNullOrEmpty(header.Value)) { @@ -471,7 +521,7 @@ private bool IsIncludedByHeaders(HeaderCollection requestHeaders) return true; } - if (requestHeaders.GetHeaders(header.Name)!.Any(h => h.Value.Contains(header.Value, StringComparison.OrdinalIgnoreCase))) + if (requestHeaders.Any(h => h.Key.Equals(header.Name, StringComparison.OrdinalIgnoreCase) && (h.Value.ToString()?.Equals(header.Value, StringComparison.OrdinalIgnoreCase) ?? false))) { _logger.LogDebug("Request header {Header} contains value {Value}", header.Name, header.Value); return true; @@ -487,111 +537,197 @@ private bool IsIncludedByHeaders(HeaderCollection requestHeaders) return false; } - // Modify response - async Task OnBeforeResponseAsync(object sender, SessionEventArgs e) + //// Modify response + //// OnBeforeResponseAsync is no longer supported, where was this used for? + //async Task OnBeforeResponseAsync(object sender, SessionEventArgs e) + //{ + // // read response headers + // if (IsProxiedHost(e.HttpClient.Request.RequestUri.Host)) + // { + // var proxyResponseArgs = new ProxyResponseArgs(e, new()) + // { + // SessionData = _pluginData[e.GetHashCode()], + // GlobalData = _proxyController.ProxyState.GlobalData + // }; + // if (!proxyResponseArgs.HasRequestUrlMatch(_urlsToWatch)) + // { + // return; + // } + + // using var scope = _logger.BeginScope(e.HttpClient.Request.Method ?? "", e.HttpClient.Request.Url, e.GetHashCode()); + + // // necessary to make the response body available to plugins + // e.HttpClient.Response.KeepBody = true; + // if (e.HttpClient.Response.HasBody) + // { + // _ = await e.GetResponseBody(); + // } + + // foreach (var plugin in _plugins.Where(p => p.Enabled)) + // { + // _cancellationToken?.ThrowIfCancellationRequested(); + + // try + // { + // await plugin.BeforeResponseAsync(proxyResponseArgs, _cancellationToken ?? CancellationToken.None); + // } + // catch (Exception ex) + // { + // ExceptionHandler(ex); + // } + // } + // } + //} + + //async Task OnAfterResponseAsync(object sender, SessionEventArgs e) + //{ + // // read response headers + // if (IsProxiedHost(e.HttpClient.Request.RequestUri.Host)) + // { + // var proxyResponseArgs = new ProxyResponseArgs(e, new()) + // { + // SessionData = _pluginData[e.GetHashCode()], + // GlobalData = _proxyController.ProxyState.GlobalData + // }; + // if (!proxyResponseArgs.HasRequestUrlMatch(_urlsToWatch)) + // { + // // clean up + // _ = _pluginData.Remove(e.GetHashCode(), out _); + // return; + // } + + // // necessary to repeat to make the response body + // // of mocked requests available to plugins + // e.HttpClient.Response.KeepBody = true; + + // using var scope = _logger.BeginScope(e.HttpClient.Request.Method ?? "", e.HttpClient.Request.Url, e.GetHashCode()); + + // var message = $"{e.HttpClient.Request.Method} {e.HttpClient.Request.Url}"; + // var loggingContext = new LoggingContext(e); + // _logger.LogRequest(message, MessageType.InterceptedResponse, loggingContext); + + // foreach (var plugin in _plugins.Where(p => p.Enabled)) + // { + // _cancellationToken?.ThrowIfCancellationRequested(); + + // try + // { + // await plugin.AfterResponseAsync(proxyResponseArgs, _cancellationToken ?? CancellationToken.None); + // } + // catch (Exception ex) + // { + // ExceptionHandler(ex); + // } + // } + + // _logger.LogRequest(message, MessageType.FinishedProcessingRequest, loggingContext); + + // // clean up + // _ = _pluginData.Remove(e.GetHashCode(), out _); + // } + //} + + // Unobtanium ResponseHandler + private async Task OnResponseAsync(object sender, ResponseEventArguments arguments, CancellationToken cancellationToken) { - // read response headers - if (IsProxiedHost(e.HttpClient.Request.RequestUri.Host)) + try { - var proxyResponseArgs = new ProxyResponseArgs(e, new()) - { - SessionData = _pluginData[e.GetHashCode()], - GlobalData = _proxyController.ProxyState.GlobalData - }; - if (!proxyResponseArgs.HasRequestUrlMatch(_urlsToWatch)) + return await OnResponseInternalAsync(sender, arguments, cancellationToken); + } + catch (OperationCanceledException) + { + _logger.LogInformation("Request was cancelled before response completed"); + return ResponseEventResponse.ContinueResponse(); + } + catch (Exception ex) + { + _logger.LogError(ex, "An error occurred while processing response {ResponseStatusCode} for request {RequestMethod} {RequestUrl}", + arguments.Response.StatusCode, arguments.Request.Method, arguments.Request.RequestUri); + return ResponseEventResponse.ContinueResponse(); + } + } + private async Task OnResponseInternalAsync(object _, ResponseEventArguments arguments, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + _logger.LogDebug("Request was cancelled before response completed"); + _logger.LogRequest($"{arguments.Request.Method} {arguments.Request.RequestUri}", MessageType.Failed, arguments.Request, arguments.RequestId); + return ResponseEventResponse.ContinueResponse(); + } + // Distributed tracing + using var activity = ActivitySource.StartActivity(nameof(OnResponseAsync), ActivityKind.Consumer, arguments.RequestActivity?.Context ?? default); + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cancellationToken ?? CancellationToken.None); + var uri = arguments.Request.RequestUri!; + + using var scope = _logger.BeginRequestScope(arguments.Request.Method, uri, arguments.RequestId); + _logger.LogRequest($"Server response: {(int)arguments.Response.StatusCode} {arguments.Response.StatusCode}", MessageType.Normal, arguments.Request, arguments.RequestId); + HttpResponseMessage? response = null; + var guidancePlugins = _plugins.Where(p => p.Enabled && p.ProvideResponseGuidanceAsync is not null); + if (guidancePlugins.Any()) + { + // Call OnResponseLogAsync for all plugins at the same time and wait for all of them to complete + var logArguments = new ResponseArguments(arguments.Request, arguments.Response, arguments.RequestId); + var logTasks = guidancePlugins + .Select(plugin => plugin.ProvideResponseGuidanceAsync!(logArguments, cts.Token)) + .ToArray(); + try { - return; + await Task.WhenAll(logTasks); } - - using var scope = _logger.BeginScope(e.HttpClient.Request.Method ?? "", e.HttpClient.Request.Url, e.GetHashCode()); - - // necessary to make the response body available to plugins - e.HttpClient.Response.KeepBody = true; - if (e.HttpClient.Response.HasBody) + catch (OperationCanceledException) { - _ = await e.GetResponseBody(); + _logger.LogDebug("Request was cancelled before response completed"); } - - foreach (var plugin in _plugins.Where(p => p.Enabled)) + catch (Exception ex) { - _cancellationToken?.ThrowIfCancellationRequested(); - - try - { - await plugin.BeforeResponseAsync(proxyResponseArgs, _cancellationToken ?? CancellationToken.None); - } - catch (Exception ex) - { - ExceptionHandler(ex); - } + _logger.LogError(ex, "An error occurred in a plugin while logging response {ResponseStatusCode} for request {RequestMethod} {RequestUrl}", + arguments.Response.StatusCode, arguments.Request.Method, uri); } } - } - async Task OnAfterResponseAsync(object sender, SessionEventArgs e) - { - // read response headers - if (IsProxiedHost(e.HttpClient.Request.RequestUri.Host)) - { - var proxyResponseArgs = new ProxyResponseArgs(e, new()) - { - SessionData = _pluginData[e.GetHashCode()], - GlobalData = _proxyController.ProxyState.GlobalData - }; - if (!proxyResponseArgs.HasRequestUrlMatch(_urlsToWatch)) - { - // clean up - _ = _pluginData.Remove(e.GetHashCode(), out _); - return; - } - - // necessary to repeat to make the response body - // of mocked requests available to plugins - e.HttpClient.Response.KeepBody = true; - using var scope = _logger.BeginScope(e.HttpClient.Request.Method ?? "", e.HttpClient.Request.Url, e.GetHashCode()); - - var message = $"{e.HttpClient.Request.Method} {e.HttpClient.Request.Url}"; - var loggingContext = new LoggingContext(e); - _logger.LogRequest(message, MessageType.InterceptedResponse, loggingContext); + foreach (var plugin in _plugins.Where(p => p.Enabled && p.OnResponseAsync is not null)) + { + cts.Token.ThrowIfCancellationRequested(); - foreach (var plugin in _plugins.Where(p => p.Enabled)) + try { - _cancellationToken?.ThrowIfCancellationRequested(); - - try - { - await plugin.AfterResponseAsync(proxyResponseArgs, _cancellationToken ?? CancellationToken.None); - } - catch (Exception ex) + var result = await plugin.OnResponseAsync!(new ResponseArguments(arguments.Request, response ?? arguments.Response, arguments.RequestId), cts.Token); + if (result is not null) { - ExceptionHandler(ex); + if (result.Request is not null) + { + // If the plugin modified the request, it is a mistake. Faulty behavior. + _logger.LogError("Plugin {PluginName} tried changing the request", plugin.Name); + } + if (result.Response is not null) + { + response = result.Response; + // Maybe exit the loop here? + } } } - - _logger.LogRequest(message, MessageType.FinishedProcessingRequest, loggingContext); - - // clean up - _ = _pluginData.Remove(e.GetHashCode(), out _); + catch (Exception ex) + { + _logger.LogError(ex, "An error occurred in plugin {PluginName} while processing response {ResponseStatusCode} for request {RequestMethod} {RequestUrl}", + plugin.Name, arguments.Response.StatusCode, arguments.Request.Method, uri); + } } - } - - // Allows overriding default certificate validation logic - Task OnCertificateValidationAsync(object sender, CertificateValidationEventArgs e) - { - // set IsValid to true/false based on Certificate Errors - if (e.SslPolicyErrors == System.Net.Security.SslPolicyErrors.None) + if (response is not null) { - e.IsValid = true; + _logger.LogRequest($"Mocked response: {(int)response.StatusCode} {response.StatusCode}", MessageType.Mocked, arguments.Request, arguments.RequestId); } - - return Task.CompletedTask; + else + { + response = arguments.Response; + _logger.LogRequest($"Return unmodified response {(int)response.StatusCode} {response.StatusCode}", MessageType.Processed, arguments.Request, arguments.RequestId); + } + _logger.LogRequest("Done", MessageType.FinishedProcessingRequest, arguments.Request, arguments.RequestId); + proxyStorage.RemoveRequestData(arguments.RequestId); + return response is not null + ? ResponseEventResponse.ModifyResponse(response) + : ResponseEventResponse.ContinueResponse(); } - // Allows overriding default client certificate selection logic during mutual authentication - Task OnCertificateSelectionAsync(object sender, CertificateSelectionEventArgs e) => - // set e.clientCertificate to override - Task.CompletedTask; - private static void PrintHotkeys() { Console.WriteLine(""); @@ -624,17 +760,17 @@ private static void ToggleSystemProxy(ToggleSystemProxyAction toggle, string? ip process.WaitForExit(); } - private static int GetProcessId(TunnelConnectSessionEventArgs e) + private static int GetProcessId(ClientDetails e) { - if (RunTime.IsWindows) + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - return e.HttpClient.ProcessId.Value; + return -1; } var psi = new ProcessStartInfo { FileName = "lsof", - Arguments = $"-i :{e.ClientRemoteEndPoint?.Port}", + Arguments = $"-i :{e.Port}", UseShellExecute = false, RedirectStandardOutput = true, CreateNoWindow = true @@ -648,7 +784,7 @@ private static int GetProcessId(TunnelConnectSessionEventArgs e) proc.WaitForExit(); var lines = output.Split([Environment.NewLine], StringSplitOptions.RemoveEmptyEntries); - var matchingLine = lines.FirstOrDefault(l => l.Contains($"{e.ClientRemoteEndPoint?.Port}->", StringComparison.OrdinalIgnoreCase)); + var matchingLine = lines.FirstOrDefault(l => l.Contains($"{e.Port}->", StringComparison.OrdinalIgnoreCase)); if (matchingLine is null) { return -1; @@ -662,7 +798,7 @@ private static int GetProcessId(TunnelConnectSessionEventArgs e) return int.TryParse(pidString, out var pid) ? pid : -1; } - private static void AddProxyHeader(Request r) => r.Headers?.AddHeader("Via", $"{r.HttpVersion} dev-proxy/{ProxyUtils.ProductVersion}"); + private static void AddProxyHeader(HttpRequestMessage r) => r.Headers.TryAddWithoutValidation("Via", $"dev-proxy/{ProxyUtils.ProductVersion}"); public override void Dispose() { diff --git a/DevProxy/Proxy/ProxyStateController.cs b/DevProxy/Proxy/ProxyStateController.cs index 1f418abe..1f781403 100644 --- a/DevProxy/Proxy/ProxyStateController.cs +++ b/DevProxy/Proxy/ProxyStateController.cs @@ -4,7 +4,6 @@ using DevProxy.Abstractions.Plugins; using DevProxy.Abstractions.Proxy; -using Titanium.Web.Proxy; namespace DevProxy.Proxy; @@ -21,7 +20,6 @@ sealed class ProxyStateController( private readonly IEnumerable _plugins = plugins; private readonly IHostApplicationLifetime _hostApplicationLifetime = hostApplicationLifetime; private readonly ILogger _logger = logger; - private ExceptionHandler ExceptionHandler => ex => _logger.LogError(ex, "An error occurred in a plugin"); public void StartRecording() { @@ -51,17 +49,16 @@ public async Task StopRecordingAsync(CancellationToken cancellationToken) ProxyState.RequestLogs.Clear(); var recordingArgs = new RecordingArgs(clonedLogs) { - GlobalData = ProxyState.GlobalData }; - foreach (var plugin in _plugins.Where(p => p.Enabled)) + foreach (var plugin in _plugins.Where(p => p.Enabled && p.HandleRecordingStopAsync is not null)) { try { - await plugin.AfterRecordingStopAsync(recordingArgs, cancellationToken); + await plugin.HandleRecordingStopAsync!(recordingArgs, cancellationToken); } catch (Exception ex) { - ExceptionHandler(ex); + _logger.LogError(ex, "Error in plugin {PluginName} after recording stop", plugin.Name); } } _logger.LogInformation("DONE"); @@ -79,7 +76,7 @@ public async Task MockRequestAsync(CancellationToken cancellationToken) } catch (Exception ex) { - ExceptionHandler(ex); + _logger.LogError(ex, "Error in plugin {PluginName} after mock request", plugin.Name); } } } diff --git a/DevProxy/packages.lock.json b/DevProxy/packages.lock.json index 69c28690..60dad098 100644 --- a/DevProxy/packages.lock.json +++ b/DevProxy/packages.lock.json @@ -4,15 +4,14 @@ "net9.0": { "Azure.Identity": { "type": "Direct", - "requested": "[1.13.2, )", - "resolved": "1.13.2", - "contentHash": "CngQVQELdzFmsGSWyGIPIUOCrII7nApMVWxVmJCKQQrWxRXcNquCsZ+njRJRnhFUfD+KMAhpjyRCaceE4EOL6A==", + "requested": "[1.14.2, )", + "resolved": "1.14.2", + "contentHash": "YhNMwOTwT+I2wIcJKSdP0ADyB2aK+JaYWZxO8LSRDm5w77LFr0ykR9xmt2ZV5T1gaI7xU6iNFIh/yW1dAlpddQ==", "dependencies": { - "Azure.Core": "1.44.1", - "Microsoft.Identity.Client": "4.67.2", - "Microsoft.Identity.Client.Extensions.Msal": "4.67.2", - "System.Memory": "4.5.5", - "System.Threading.Tasks.Extensions": "4.5.4" + "Azure.Core": "1.46.1", + "Microsoft.Identity.Client": "4.73.1", + "Microsoft.Identity.Client.Extensions.Msal": "4.73.1", + "System.Memory": "4.5.5" } }, "Microsoft.EntityFrameworkCore.Sqlite": { @@ -31,25 +30,6 @@ "System.Text.Json": "9.0.4" } }, - "Microsoft.Extensions.Configuration": { - "type": "Direct", - "requested": "[9.0.4, )", - "resolved": "9.0.4", - "contentHash": "KIVBrMbItnCJDd1RF4KEaE8jZwDJcDUJW5zXpbwQ05HNYTK1GveHxHK0B3SjgDJuR48GRACXAO+BLhL8h34S7g==", - "dependencies": { - "Microsoft.Extensions.Configuration.Abstractions": "9.0.4", - "Microsoft.Extensions.Primitives": "9.0.4" - } - }, - "Microsoft.Extensions.Configuration.Binder": { - "type": "Direct", - "requested": "[9.0.4, )", - "resolved": "9.0.4", - "contentHash": "cdrjcl9RIcwt3ECbnpP0Gt1+pkjdW90mq5yFYy8D9qRj2NqFFcv3yDp141iEamsd9E218sGxK8WHaIOcrqgDJg==", - "dependencies": { - "Microsoft.Extensions.Configuration.Abstractions": "9.0.4" - } - }, "Microsoft.Extensions.Configuration.Json": { "type": "Direct", "requested": "[9.0.4, )", @@ -68,6 +48,20 @@ "resolved": "9.0.6", "contentHash": "1HJCAbwukNEoYbHgHbKHmenU0V/0huw8+i7Qtf5rLUG1E+3kEwRJQxpwD3wbTEagIgPSQisNgJTvmUX9yYVc6g==" }, + "Microsoft.Extensions.Http": { + "type": "Direct", + "requested": "[9.0.8, )", + "resolved": "9.0.8", + "contentHash": "jDj+4aDByk47oESlDDTtk6LWzlXlmoCsjCn6ihd+i9OntN885aPLszUII5+w0B/7wYSZcS3KdjqLAIhKLSiBXQ==", + "dependencies": { + "Microsoft.Extensions.Configuration.Abstractions": "9.0.8", + "Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.8", + "Microsoft.Extensions.Diagnostics": "9.0.8", + "Microsoft.Extensions.Logging": "9.0.8", + "Microsoft.Extensions.Logging.Abstractions": "9.0.8", + "Microsoft.Extensions.Options": "9.0.8" + } + }, "Microsoft.Extensions.Logging.Console": { "type": "Direct", "requested": "[9.0.4, )", @@ -118,26 +112,45 @@ "resolved": "13.0.3", "contentHash": "HrC5BXdl00IP9zeV+0Z848QWPAoCr9P3bDEZguI+gkLcBKAOxix/tLEAAHC+UvDNPv4a2d18lOReHMOagPa+zQ==" }, - "OpenTelemetry": { + "OpenTelemetry.Exporter.OpenTelemetryProtocol": { "type": "Direct", "requested": "[1.12.0, )", "resolved": "1.12.0", - "contentHash": "aIEu2O3xFOdwIVH0AJsIHPIMH1YuX18nzu7BHyaDNQ6NWSk4Zyrs9Pp6y8SATuSbvdtmvue4mj/QZ3838srbwA==", + "contentHash": "7LzQSPhz5pNaL4xZgT3wkZODA1NLrEq3bet8KDHgtaJ9q+VNP7wmiZky8gQfMkB4FXuI/pevT8ZurL4p5997WA==", "dependencies": { - "Microsoft.Extensions.Diagnostics.Abstractions": "9.0.0", - "Microsoft.Extensions.Logging.Configuration": "9.0.0", - "OpenTelemetry.Api.ProviderBuilderExtensions": "1.12.0" + "OpenTelemetry": "1.12.0" } }, - "OpenTelemetry.Exporter.OpenTelemetryProtocol": { + "OpenTelemetry.Extensions.Hosting": { "type": "Direct", "requested": "[1.12.0, )", "resolved": "1.12.0", - "contentHash": "7LzQSPhz5pNaL4xZgT3wkZODA1NLrEq3bet8KDHgtaJ9q+VNP7wmiZky8gQfMkB4FXuI/pevT8ZurL4p5997WA==", + "contentHash": "6/8O6rsJRwslg5/Fm3bscBelw4Yh9T9CN24p7cAsuEFkrmmeSO9gkYUCK02Qi+CmPM2KHYTLjKi0lJaCsDMWQA==", "dependencies": { + "Microsoft.Extensions.Hosting.Abstractions": "9.0.0", "OpenTelemetry": "1.12.0" } }, + "OpenTelemetry.Instrumentation.Http": { + "type": "Direct", + "requested": "[1.12.0, )", + "resolved": "1.12.0", + "contentHash": "0rW+MbHgUQAdbvBtRxPYoQBosbNdWegL7cYkRlxq+KQ/VFyU8itt4pWTccmu1/FWmTgqJyT3LaujyDZoRrm8Yg==", + "dependencies": { + "Microsoft.Extensions.Configuration": "9.0.0", + "Microsoft.Extensions.Options": "9.0.0", + "OpenTelemetry.Api.ProviderBuilderExtensions": "[1.12.0, 2.0.0)" + } + }, + "OpenTelemetry.Instrumentation.Runtime": { + "type": "Direct", + "requested": "[1.12.0, )", + "resolved": "1.12.0", + "contentHash": "xmd0TAm2x+T3ztdf5BolIwLPh+Uy6osaBeIQtCXv611PN7h/Pnhsjg5lU2hkAWj7M7ns74U5wtVpS8DXmJ+94w==", + "dependencies": { + "OpenTelemetry.Api": "[1.12.0, 2.0.0)" + } + }, "Swashbuckle.AspNetCore": { "type": "Direct", "requested": "[8.1.1, )", @@ -168,35 +181,23 @@ }, "Unobtanium.Web.Proxy": { "type": "Direct", - "requested": "[0.1.5, )", - "resolved": "0.1.5", - "contentHash": "HiICGm0e44+i4aVHpLn+aphmSC2eQnDvlTttw1rE0hntOZKoLGRy37sydqqbRP1ZokMf3Mt0GEgSWxDwnucKGg==", + "requested": "[0.9.1-beta-2, )", + "resolved": "0.9.1-beta-2", + "contentHash": "/nRDmaQQ9xRqWiSXf4mAN32anOM/e8aZcnD/Kty376d1JliqiKmR58sHS/yO8LITCo2eSGzH+cobm+e8MmpTPQ==", "dependencies": { - "BouncyCastle.Cryptography": "2.4.0", - "Microsoft.Extensions.Logging.Abstractions": "8.0.1", - "System.Runtime.CompilerServices.Unsafe": "6.0.0" + "Unobtanium.Web.Proxy.Events": "0.9.1-beta-2" } }, "Azure.Core": { "type": "Transitive", - "resolved": "1.44.1", - "contentHash": "YyznXLQZCregzHvioip07/BkzjuWNXogJEVz9T5W6TwjNr17ax41YGzYMptlo2G10oLCuVPoyva62y0SIRDixg==", + "resolved": "1.46.1", + "contentHash": "iE5DPOlGsN5kCkF4gN+vasN1RihO0Ypie92oQ5tohQYiokmnrrhLnee+3zcE8n7vB6ZAzhPTfUGAEXX/qHGkYA==", "dependencies": { - "Microsoft.Bcl.AsyncInterfaces": "6.0.0", - "System.ClientModel": "1.1.0", - "System.Diagnostics.DiagnosticSource": "6.0.1", - "System.Memory.Data": "6.0.0", - "System.Numerics.Vectors": "4.5.0", - "System.Text.Encodings.Web": "6.0.0", - "System.Text.Json": "6.0.10", - "System.Threading.Tasks.Extensions": "4.5.4" + "Microsoft.Bcl.AsyncInterfaces": "8.0.0", + "System.ClientModel": "1.4.1", + "System.Memory.Data": "6.0.1" } }, - "BouncyCastle.Cryptography": { - "type": "Transitive", - "resolved": "2.4.0", - "contentHash": "SwXsAV3sMvAU/Nn31pbjhWurYSjJ+/giI/0n6tCrYoupEK34iIHCuk3STAd9fx8yudM85KkLSVdn951vTng/vQ==" - }, "Markdig": { "type": "Transitive", "resolved": "0.41.3", @@ -204,8 +205,8 @@ }, "Microsoft.Bcl.AsyncInterfaces": { "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "UcSjPsst+DfAdJGVDsu346FX0ci0ah+lw3WRtn18NUwEqRt70HaOQ7lI72vy3+1LxtqI3T5GWwV39rQSrCzAeg==" + "resolved": "8.0.0", + "contentHash": "3WA9q9yVqJp222P3x1wYIGDAkpjAku0TMUaaQV22g6L67AI0LdOIrVS7Ht2vJfLHGSPVuqN94vIr15qn+HEkHw==" }, "Microsoft.Data.Sqlite.Core": { "type": "Transitive", @@ -287,12 +288,29 @@ "Microsoft.Extensions.Primitives": "9.0.4" } }, + "Microsoft.Extensions.Configuration": { + "type": "Transitive", + "resolved": "9.0.8", + "contentHash": "6m+8Xgmf8UWL0p/oGqBM+0KbHE5/ePXbV1hKXgC59zEv0aa0DW5oiiyxDbK5kH5j4gIvyD5uWL0+HadKBJngvQ==", + "dependencies": { + "Microsoft.Extensions.Configuration.Abstractions": "9.0.8", + "Microsoft.Extensions.Primitives": "9.0.8" + } + }, "Microsoft.Extensions.Configuration.Abstractions": { "type": "Transitive", - "resolved": "9.0.4", - "contentHash": "0LN/DiIKvBrkqp7gkF3qhGIeZk6/B63PthAHjQsxymJfIBcz0kbf4/p/t4lMgggVxZ+flRi5xvTwlpPOoZk8fg==", + "resolved": "9.0.8", + "contentHash": "yNou2KM35RvzOh4vUFtl2l33rWPvOCoba+nzEDJ+BgD8aOL/jew4WPCibQvntRfOJ2pJU8ARygSMD+pdjvDHuA==", "dependencies": { - "Microsoft.Extensions.Primitives": "9.0.4" + "Microsoft.Extensions.Primitives": "9.0.8" + } + }, + "Microsoft.Extensions.Configuration.Binder": { + "type": "Transitive", + "resolved": "9.0.8", + "contentHash": "0vK9DnYrYChdiH3yRZWkkp4x4LbrfkWEdBc5HOsQ8t/0CLOWKXKkkhOE8A1shlex0hGydbGrhObeypxz/QTm+w==", + "dependencies": { + "Microsoft.Extensions.Configuration.Abstractions": "9.0.8" } }, "Microsoft.Extensions.Configuration.FileExtensions": { @@ -309,29 +327,39 @@ }, "Microsoft.Extensions.DependencyInjection": { "type": "Transitive", - "resolved": "9.0.4", - "contentHash": "f2MTUaS2EQ3lX4325ytPAISZqgBfXmY0WvgD80ji6Z20AoDNiCESxsqo6mFRwHJD/jfVKRw9FsW6+86gNre3ug==", + "resolved": "9.0.8", + "contentHash": "JJjI2Fa+QtZcUyuNjbKn04OjIUX5IgFGFu/Xc+qvzh1rXdZHLcnqqVXhR4093bGirTwacRlHiVg1XYI9xum6QQ==", "dependencies": { - "Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.4" + "Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.8" } }, "Microsoft.Extensions.DependencyInjection.Abstractions": { "type": "Transitive", - "resolved": "9.0.4", - "contentHash": "UI0TQPVkS78bFdjkTodmkH0Fe8lXv9LnhGFKgKrsgUJ5a5FVdFRcgjIkBVLbGgdRhxWirxH/8IXUtEyYJx6GQg==" + "resolved": "9.0.8", + "contentHash": "xY3lTjj4+ZYmiKIkyWitddrp1uL5uYiweQjqo4BKBw01ZC4HhcfgLghDpPZcUlppgWAFqFy9SgkiYWOMx365pw==" }, "Microsoft.Extensions.DependencyModel": { "type": "Transitive", "resolved": "9.0.4", "contentHash": "ACtnvl3H3M/f8Z42980JxsNu7V9PPbzys4vBs83ZewnsgKd7JeYK18OMPo0g+MxAHrpgMrjmlinXDiaSRPcVnA==" }, + "Microsoft.Extensions.Diagnostics": { + "type": "Transitive", + "resolved": "9.0.8", + "contentHash": "BKkLCFXzJvNmdngeYBf72VXoZqTJSb1orvjdzDLaGobicoGFBPW8ug2ru1nnEewMEwJzMgnsjHQY8EaKWmVhKg==", + "dependencies": { + "Microsoft.Extensions.Configuration": "9.0.8", + "Microsoft.Extensions.Diagnostics.Abstractions": "9.0.8", + "Microsoft.Extensions.Options.ConfigurationExtensions": "9.0.8" + } + }, "Microsoft.Extensions.Diagnostics.Abstractions": { "type": "Transitive", - "resolved": "9.0.0", - "contentHash": "1K8P7XzuzX8W8pmXcZjcrqS6x5eSSdvhQohmcpgiQNY/HlDAlnrhR9dvlURfFz428A+RTCJpUyB+aKTA6AgVcQ==", + "resolved": "9.0.8", + "contentHash": "UDY7blv4DCyIJ/8CkNrQKLaAZFypXQavRZ2DWf/2zi1mxYYKKw2t8AOCBWxNntyPZHPGhtEmL3snFM98ADZqTw==", "dependencies": { - "Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0", - "Microsoft.Extensions.Options": "9.0.0" + "Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.8", + "Microsoft.Extensions.Options": "9.0.8" } }, "Microsoft.Extensions.FileProviders.Abstractions": { @@ -352,22 +380,34 @@ "Microsoft.Extensions.Primitives": "9.0.4" } }, + "Microsoft.Extensions.Hosting.Abstractions": { + "type": "Transitive", + "resolved": "9.0.0", + "contentHash": "yUKJgu81ExjvqbNWqZKshBbLntZMbMVz/P7Way2SBx7bMqA08Mfdc9O7hWDKAiSp+zPUGT6LKcSCQIPeDK+CCw==", + "dependencies": { + "Microsoft.Extensions.Configuration.Abstractions": "9.0.0", + "Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0", + "Microsoft.Extensions.Diagnostics.Abstractions": "9.0.0", + "Microsoft.Extensions.FileProviders.Abstractions": "9.0.0", + "Microsoft.Extensions.Logging.Abstractions": "9.0.0" + } + }, "Microsoft.Extensions.Logging": { "type": "Transitive", - "resolved": "9.0.4", - "contentHash": "xW6QPYsqhbuWBO9/1oA43g/XPKbohJx+7G8FLQgQXIriYvY7s+gxr2wjQJfRoPO900dvvv2vVH7wZovG+M1m6w==", + "resolved": "9.0.8", + "contentHash": "Z/7ze+0iheT7FJeZPqJKARYvyC2bmwu3whbm/48BJjdlGVvgDguoCqJIkI/67NkroTYobd5geai1WheNQvWrgA==", "dependencies": { - "Microsoft.Extensions.DependencyInjection": "9.0.4", - "Microsoft.Extensions.Logging.Abstractions": "9.0.4", - "Microsoft.Extensions.Options": "9.0.4" + "Microsoft.Extensions.DependencyInjection": "9.0.8", + "Microsoft.Extensions.Logging.Abstractions": "9.0.8", + "Microsoft.Extensions.Options": "9.0.8" } }, "Microsoft.Extensions.Logging.Abstractions": { "type": "Transitive", - "resolved": "9.0.4", - "contentHash": "0MXlimU4Dud6t+iNi5NEz3dO2w1HXdhoOLaYFuLPCjAsvlPQGwOT6V2KZRMLEhCAm/stSZt1AUv0XmDdkjvtbw==", + "resolved": "9.0.8", + "contentHash": "pYnAffJL7ARD/HCnnPvnFKSIHnTSmWz84WIlT9tPeQ4lHNiu0Az7N/8itihWvcF8sT+VVD5lq8V+ckMzu4SbOw==", "dependencies": { - "Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.4" + "Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.8" } }, "Microsoft.Extensions.Logging.Configuration": { @@ -387,34 +427,34 @@ }, "Microsoft.Extensions.Options": { "type": "Transitive", - "resolved": "9.0.4", - "contentHash": "fiFI2+58kicqVZyt/6obqoFwHiab7LC4FkQ3mmiBJ28Yy4fAvy2+v9MRnSvvlOO8chTOjKsdafFl/K9veCPo5g==", + "resolved": "9.0.8", + "contentHash": "OmTaQ0v4gxGQkehpwWIqPoEiwsPuG/u4HUsbOFoWGx4DKET2AXzopnFe/fE608FIhzc/kcg2p8JdyMRCCUzitQ==", "dependencies": { - "Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.4", - "Microsoft.Extensions.Primitives": "9.0.4" + "Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.8", + "Microsoft.Extensions.Primitives": "9.0.8" } }, "Microsoft.Extensions.Options.ConfigurationExtensions": { "type": "Transitive", - "resolved": "9.0.4", - "contentHash": "aridVhAT3Ep+vsirR1pzjaOw0Jwiob6dc73VFQn2XmDfBA2X98M8YKO1GarvsXRX7gX1Aj+hj2ijMzrMHDOm0A==", + "resolved": "9.0.8", + "contentHash": "eW2s6n06x0w6w4nsX+SvpgsFYkl+Y0CttYAt6DKUXeqprX+hzNqjSfOh637fwNJBg7wRBrOIRHe49gKiTgJxzQ==", "dependencies": { - "Microsoft.Extensions.Configuration.Abstractions": "9.0.4", - "Microsoft.Extensions.Configuration.Binder": "9.0.4", - "Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.4", - "Microsoft.Extensions.Options": "9.0.4", - "Microsoft.Extensions.Primitives": "9.0.4" + "Microsoft.Extensions.Configuration.Abstractions": "9.0.8", + "Microsoft.Extensions.Configuration.Binder": "9.0.8", + "Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.8", + "Microsoft.Extensions.Options": "9.0.8", + "Microsoft.Extensions.Primitives": "9.0.8" } }, "Microsoft.Extensions.Primitives": { "type": "Transitive", - "resolved": "9.0.4", - "contentHash": "SPFyMjyku1nqTFFJ928JAMd0QnRe4xjE7KeKnZMWXf3xk+6e0WiOZAluYtLdbJUXtsl2cCRSi8cBquJ408k8RA==" + "resolved": "9.0.8", + "contentHash": "tizSIOEsIgSNSSh+hKeUVPK7xmTIjR8s+mJWOu1KXV3htvNQiPMFRMO17OdI1y/4ZApdBVk49u/08QGC9yvLug==" }, "Microsoft.Identity.Client": { "type": "Transitive", - "resolved": "4.67.2", - "contentHash": "37t0TfekfG6XM8kue/xNaA66Qjtti5Qe1xA41CK+bEd8VD76/oXJc+meFJHGzygIC485dCpKoamG/pDfb9Qd7Q==", + "resolved": "4.73.1", + "contentHash": "NnDLS8QwYqO5ZZecL2oioi1LUqjh5Ewk4bMLzbgiXJbQmZhDLtKwLxL3DpGMlQAJ2G4KgEnvGPKa+OOgffeJbw==", "dependencies": { "Microsoft.IdentityModel.Abstractions": "6.35.0", "System.Diagnostics.DiagnosticSource": "6.0.1" @@ -422,10 +462,10 @@ }, "Microsoft.Identity.Client.Extensions.Msal": { "type": "Transitive", - "resolved": "4.67.2", - "contentHash": "DKs+Lva6csEUZabw+JkkjtFgVmcXh4pJeQy5KH5XzPOaKNoZhAMYj1qpKd97qYTZKXIFH12bHPk0DA+6krw+Cw==", + "resolved": "4.73.1", + "contentHash": "xDztAiV2F0wI0W8FLKv5cbaBefyLD6JVaAsvgSN7bjWNCzGYzHbcOEIP5s4TJXUpQzMfUyBsFl1mC6Zmgpz0PQ==", "dependencies": { - "Microsoft.Identity.Client": "4.67.2", + "Microsoft.Identity.Client": "4.73.1", "System.Security.Cryptography.ProtectedData": "4.5.0" } }, @@ -498,6 +538,16 @@ "Newtonsoft.Json": "13.0.3" } }, + "OpenTelemetry": { + "type": "Transitive", + "resolved": "1.12.0", + "contentHash": "aIEu2O3xFOdwIVH0AJsIHPIMH1YuX18nzu7BHyaDNQ6NWSk4Zyrs9Pp6y8SATuSbvdtmvue4mj/QZ3838srbwA==", + "dependencies": { + "Microsoft.Extensions.Diagnostics.Abstractions": "9.0.0", + "Microsoft.Extensions.Logging.Configuration": "9.0.0", + "OpenTelemetry.Api.ProviderBuilderExtensions": "1.12.0" + } + }, "OpenTelemetry.Api": { "type": "Transitive", "resolved": "1.12.0", @@ -578,11 +628,11 @@ }, "System.ClientModel": { "type": "Transitive", - "resolved": "1.1.0", - "contentHash": "UocOlCkxLZrG2CKMAAImPcldJTxeesHnHGHwhJ0pNlZEvEXcWKuQvVOER2/NiOkJGRJk978SNdw3j6/7O9H1lg==", + "resolved": "1.4.1", + "contentHash": "MY7eFGKp+Hu7Ciub8wigQ0odGrkml4eTjUy8d5Bu2eGAVvm8Qskkq+YuXiiS5wMJGq7iSvqseV4skd5WxTUdDA==", "dependencies": { - "System.Memory.Data": "1.0.2", - "System.Text.Json": "6.0.9" + "Microsoft.Extensions.Logging.Abstractions": "8.0.3", + "System.Memory.Data": "6.0.1" } }, "System.Diagnostics.DiagnosticSource": { @@ -597,44 +647,26 @@ }, "System.Memory.Data": { "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "ntFHArH3I4Lpjf5m4DCXQHJuGwWPNVJPaAvM95Jy/u+2Yzt2ryiyIN04LAogkjP9DeRcEOiviAjQotfmPq/FrQ==", - "dependencies": { - "System.Text.Json": "6.0.0" - } - }, - "System.Numerics.Vectors": { - "type": "Transitive", - "resolved": "4.5.0", - "contentHash": "QQTlPTl06J/iiDbJCiepZ4H//BVraReU4O4EoRw1U02H5TLUIT7xn3GnDp9AXPSlJUDyFs4uWjWafNX6WrAojQ==" - }, - "System.Runtime.CompilerServices.Unsafe": { - "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "/iUeP3tq1S0XdNNoMz5C9twLSrM/TH+qElHkXWaPvuNOt+99G75NrV0OS2EqHx5wMN7popYjpc8oTjC1y16DLg==" + "resolved": "6.0.1", + "contentHash": "yliDgLh9S9Mcy5hBIdZmX6yphYIW3NH+3HN1kV1m7V1e0s7LNTw/tHNjJP4U9nSMEgl3w1TzYv/KA1Tg9NYy6w==" }, "System.Security.Cryptography.ProtectedData": { "type": "Transitive", "resolved": "4.5.0", "contentHash": "wLBKzFnDCxP12VL9ANydSYhk59fC4cvOr9ypYQLPnAj48NQIhqnjdD2yhP8yEKyBJEjERWS9DisKL7rX5eU25Q==" }, - "System.Text.Encodings.Web": { - "type": "Transitive", - "resolved": "6.0.0", - "contentHash": "Vg8eB5Tawm1IFqj4TVK1czJX89rhFxJo9ELqc/Eiq0eXy13RK00eubyU6TJE6y+GQXjyV5gSfiewDUZjQgSE0w==", - "dependencies": { - "System.Runtime.CompilerServices.Unsafe": "6.0.0" - } - }, "System.Text.Json": { "type": "Transitive", "resolved": "9.0.4", "contentHash": "pYtmpcO6R3Ef1XilZEHgXP2xBPVORbYEzRP7dl0IAAbN8Dm+kfwio8aCKle97rAWXOExr292MuxWYurIuwN62g==" }, - "System.Threading.Tasks.Extensions": { + "Unobtanium.Web.Proxy.Events": { "type": "Transitive", - "resolved": "4.5.4", - "contentHash": "zteT+G8xuGu6mS+mzDzYXbzS7rd3K6Fjb9RiZlYlJPam2/hU7JCBZBVEcywNuR+oZ1ncTvc/cq0faRr3P01OVg==" + "resolved": "0.9.1-beta-2", + "contentHash": "OaDCVpnDYah/2DDqAwyhFwl/wt1Ggat8yGXF8GqtLDuUqxv/cYtowjI6ZMnKmLM/Bkz6gMM4lSITY0cO0rPVgA==", + "dependencies": { + "Microsoft.Extensions.Logging.Abstractions": "8.0.1" + } }, "YamlDotNet": { "type": "Transitive", @@ -654,7 +686,6 @@ "Newtonsoft.Json.Schema": "[4.0.1, )", "Scriban": "[6.2.1, )", "System.CommandLine": "[2.0.0-beta5.25306.1, )", - "Unobtanium.Web.Proxy": "[0.1.5, )", "YamlDotNet": "[16.3.0, )" } }