diff --git a/Modules/Reflection/IOperationInterceptor.cs b/Modules/Reflection/IOperationInterceptor.cs new file mode 100644 index 00000000..e6b76681 --- /dev/null +++ b/Modules/Reflection/IOperationInterceptor.cs @@ -0,0 +1,39 @@ +using GenHTTP.Api.Protocol; + +using GenHTTP.Modules.Reflection.Operations; + +namespace GenHTTP.Modules.Reflection; + +/// +/// A result returned by an interceptor. +/// +/// The payload of the response +public sealed class InterceptionResult(object? payload = null) : Result(payload); + +/// +/// A piece of logic to be executed before the +/// actual method invocation. Triggered by methods +/// annotated with the +/// attribute. +/// +public interface IOperationInterceptor +{ + + /// + /// Invoked after the instance has been created to configure + /// the interceptor with the originally used attribute. Allows + /// the interceptor to read configuration data as needed. + /// + /// The original attribute instance on the method definition + void Configure(object attribute); + + /// + /// Invoked on every operation call by the client. + /// + /// The request which caused this invocation + /// The currently executed operation + /// The operation arguments as derived by the framework + /// If a result is returned, it will be converted into a response and the method is not invoked + ValueTask InterceptAsync(IRequest request, Operation operation, IReadOnlyDictionary arguments); + +} diff --git a/Modules/Reflection/InterceptWithAttribute.cs b/Modules/Reflection/InterceptWithAttribute.cs new file mode 100644 index 00000000..5c5b1638 --- /dev/null +++ b/Modules/Reflection/InterceptWithAttribute.cs @@ -0,0 +1,16 @@ +namespace GenHTTP.Modules.Reflection; + +/// +/// When annotated on a service method, the method handler +/// will create an instance of T and invoke it before +/// the actual method invocation. +/// +/// The type of interceptor to be used +/// +/// Allows to implement concerns on operation level such as authorization. +/// +[AttributeUsage(AttributeTargets.Method)] +public class InterceptWithAttribute : Attribute where T : IOperationInterceptor, new() +{ + +} diff --git a/Modules/Reflection/MethodHandler.cs b/Modules/Reflection/MethodHandler.cs index 08b04e76..f009edcf 100644 --- a/Modules/Reflection/MethodHandler.cs +++ b/Modules/Reflection/MethodHandler.cs @@ -1,9 +1,11 @@ using System.Reflection; using System.Runtime.ExceptionServices; using System.Text.RegularExpressions; + using GenHTTP.Api.Content; using GenHTTP.Api.Protocol; using GenHTTP.Api.Routing; + using GenHTTP.Modules.Conversion.Serializers.Forms; using GenHTTP.Modules.Reflection.Operations; @@ -19,7 +21,7 @@ namespace GenHTTP.Modules.Reflection; /// public sealed class MethodHandler : IHandler { - private static readonly object?[] NoArguments = []; + private static readonly Dictionary NoArguments = []; #region Get-/Setters @@ -63,12 +65,19 @@ public MethodHandler(Operation operation, object instance, IMethodConfiguration { var arguments = await GetArguments(request); - var result = Invoke(arguments); + var interception = await InterceptAsync(request, arguments); + + if (interception is not null) + { + return interception; + } + + var result = Invoke(arguments.Values.ToArray()); return await ResponseProvider.GetResponseAsync(request, this, Operation, await UnwrapAsync(result), null); } - private async ValueTask GetArguments(IRequest request) + private async ValueTask> GetArguments(IRequest request) { var targetParameters = Operation.Method.GetParameters(); @@ -85,7 +94,7 @@ public MethodHandler(Operation operation, object instance, IMethodConfiguration if (targetParameters.Length > 0) { - var targetArguments = new object?[targetParameters.Length]; + var targetArguments = new Dictionary(targetParameters.Length); var bodyArguments = FormFormat.GetContent(request); @@ -97,7 +106,7 @@ public MethodHandler(Operation operation, object instance, IMethodConfiguration { if (Operation.Arguments.TryGetValue(par.Name, out var arg)) { - targetArguments[i] = arg.Source switch + targetArguments[arg.Name] = arg.Source switch { OperationArgumentSource.Injected => ArgumentProvider.GetInjectedArgument(request, this, arg, Registry), OperationArgumentSource.Path => ArgumentProvider.GetPathArgument(arg, sourceParameters, Registry), @@ -119,6 +128,22 @@ public MethodHandler(Operation operation, object instance, IMethodConfiguration public ValueTask PrepareAsync() => ValueTask.CompletedTask; + private async ValueTask InterceptAsync(IRequest request, IReadOnlyDictionary arguments) + { + if (Operation.Interceptors.Count > 0) + { + foreach (var interceptor in Operation.Interceptors) + { + if (await interceptor.InterceptAsync(request, Operation, arguments) is IResultWrapper result) + { + return await ResponseProvider.GetResponseAsync(request, this, Operation, result.Payload, (r) => result.Apply(r)); + } + } + } + + return null; + } + private object? Invoke(object?[] arguments) { try diff --git a/Modules/Reflection/Operations/InterceptorAnalyzer.cs b/Modules/Reflection/Operations/InterceptorAnalyzer.cs new file mode 100644 index 00000000..750a407d --- /dev/null +++ b/Modules/Reflection/Operations/InterceptorAnalyzer.cs @@ -0,0 +1,46 @@ +using System.Reflection; + +namespace GenHTTP.Modules.Reflection.Operations; + +public static class InterceptorAnalyzer +{ + + public static IReadOnlyList GetInterceptors(MethodInfo method) + { + var interceptors = new List(); + + foreach (var attribute in method.GetCustomAttributes(typeof(InterceptWithAttribute<>), true)) + { + var interceptorType = FindInterceptorType(attribute.GetType()); + + if (interceptorType != null) + { + if (Activator.CreateInstance(interceptorType) is IOperationInterceptor interceptor) + { + interceptor.Configure(attribute); + interceptors.Add(interceptor); + } + } + } + + return interceptors; + } + + private static Type? FindInterceptorType(Type attributeType) + { + var current = attributeType; + + while (current != null) + { + if (current.IsGenericType && current.GetGenericTypeDefinition() == typeof(InterceptWithAttribute<>)) + { + return current.GetGenericArguments()[0]; + } + + current = current.BaseType; + } + + return null; + } + +} diff --git a/Modules/Reflection/Operations/Operation.cs b/Modules/Reflection/Operations/Operation.cs index 449821b2..18c0c710 100644 --- a/Modules/Reflection/Operations/Operation.cs +++ b/Modules/Reflection/Operations/Operation.cs @@ -22,6 +22,11 @@ public sealed class Operation /// public IReadOnlyDictionary Arguments { get; } + /// + /// The interceptors to be executed for this operation. + /// + public IReadOnlyList Interceptors { get; } + /// /// The result generated by this operation. /// @@ -31,12 +36,13 @@ public sealed class Operation #region Initialization - public Operation(MethodInfo method, OperationPath path, OperationResult result, IReadOnlyDictionary arguments) + public Operation(MethodInfo method, OperationPath path, OperationResult result, IReadOnlyDictionary arguments, IReadOnlyList interceptors) { Method = method; Path = path; Result = result; Arguments = arguments; + Interceptors = interceptors; } #endregion diff --git a/Modules/Reflection/Operations/OperationBuilder.cs b/Modules/Reflection/Operations/OperationBuilder.cs index ff4962bf..1997c258 100644 --- a/Modules/Reflection/Operations/OperationBuilder.cs +++ b/Modules/Reflection/Operations/OperationBuilder.cs @@ -104,7 +104,9 @@ public static Operation Create(string? definition, MethodInfo method, MethodRegi var result = SignatureAnalyzer.GetResult(method, registry); - return new Operation(method, path, result, arguments); + var interceptors = InterceptorAnalyzer.GetInterceptors(method); + + return new Operation(method, path, result, arguments, interceptors); } private static bool CheckWildcardRoute(Type returnType) diff --git a/Testing/Acceptance/Modules/Reflection/InterceptionTests.cs b/Testing/Acceptance/Modules/Reflection/InterceptionTests.cs new file mode 100644 index 00000000..bb82ebc4 --- /dev/null +++ b/Testing/Acceptance/Modules/Reflection/InterceptionTests.cs @@ -0,0 +1,108 @@ +using System.Net; +using GenHTTP.Api.Content; +using GenHTTP.Api.Protocol; + +using GenHTTP.Modules.Functional; +using GenHTTP.Modules.Reflection; +using GenHTTP.Modules.Reflection.Operations; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace GenHTTP.Testing.Acceptance.Modules.Reflection; + +[TestClass] +public class InterceptionTests +{ + + #region Supporting data structures + + [AttributeUsage(AttributeTargets.Method)] + public class MyAttribute(string command) : InterceptWithAttribute + { + + public string Command => command; + + } + + public class MyInterceptor : IOperationInterceptor + { + + public string? Command { get; private set; } + + public void Configure(object attribute) + { + if (attribute is MyAttribute my) + { + Command = my.Command; + } + } + + public ValueTask InterceptAsync(IRequest request, Operation operation, IReadOnlyDictionary arguments) + { + if (Command == "intercept") + { + var result = new InterceptionResult("Nah"); + result.Status(ResponseStatus.Forbidden); + + return new(result); + } + + if (Command == "throw") + { + throw new ProviderException(ResponseStatus.Forbidden, "Nah"); + } + + return default; + } + + } + + #endregion + + #region Tests + + [TestMethod] + public async Task TestInterception() + { + var app = Inline.Create().Get([My("intercept")] () => 42); + + await using var host = await TestHost.RunAsync(app); + + using var response = await host.GetResponseAsync(); + + await response.AssertStatusAsync(HttpStatusCode.Forbidden); + + Assert.AreEqual("Nah", await response.GetContentAsync()); + } + + [TestMethod] + public async Task TestPassThrough() + { + var app = Inline.Create().Get([My("pass")] () => 42); + + await using var host = await TestHost.RunAsync(app); + + using var response = await host.GetResponseAsync(); + + await response.AssertStatusAsync(HttpStatusCode.OK); + + Assert.AreEqual("42", await response.GetContentAsync()); + } + + [TestMethod] + public async Task TestException() + { + var app = Inline.Create().Get([My("throw")] () => 42); + + await using var host = await TestHost.RunAsync(app); + + using var response = await host.GetResponseAsync(); + + await response.AssertStatusAsync(HttpStatusCode.Forbidden); + + AssertX.Contains("Nah", await response.GetContentAsync()); + } + + #endregion + +}