From 2fe46bc58a1f3fe09bd242015efebfaea722a4fe Mon Sep 17 00:00:00 2001 From: Adrien YHUEL Date: Wed, 18 Dec 2024 22:08:35 +0100 Subject: [PATCH] feat: :white_check_mark: separate to an utils function and add tests --- http.go | 26 ++------------------------ utils.go | 34 ++++++++++++++++++++++++++++++++++ utils_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 24 deletions(-) create mode 100644 utils_test.go diff --git a/http.go b/http.go index ef895f0..90c5c27 100644 --- a/http.go +++ b/http.go @@ -8,37 +8,15 @@ import ( "io" "net" "net/http" - "strconv" - "strings" "github.com/corazawaf/coraza/v3/types" - "github.com/caddyserver/caddy/v2/modules/caddyhttp" ) // Copied from https://github.com/corazawaf/coraza/blob/main/http/middleware.go func processRequest(tx types.Transaction, req *http.Request) (*types.Interruption, error) { - var ( - client string - cport int - ) - // IMPORTANT: Some http.Request.RemoteAddr implementations will not contain port or contain IPV6: [2001:db8::1]:8080 - idx := strings.LastIndexByte(req.RemoteAddr, ':') - if idx != -1 { - client = req.RemoteAddr[:idx] - cport, _ = strconv.Atoi(req.RemoteAddr[idx+1:]) - } - if address, ok := caddyhttp.GetVar(req.Context(), caddyhttp.ClientIPVarKey).(string); ok { - clientIp, clientPort, _ := net.SplitHostPort(address) - if clientIp != "" { - client = clientIp - } else if address != "" { - client = address - } - if clientPort != "" { - cport, _ = strconv.Atoi(clientPort) - } - } + + client, cport := getClientAddress(req) var in *types.Interruption // There is no socket access in the request object, so we neither know the server client nor port. diff --git a/utils.go b/utils.go index 94f8cbc..7e7700d 100644 --- a/utils.go +++ b/utils.go @@ -4,7 +4,11 @@ package coraza import ( + "github.com/caddyserver/caddy/v2/modules/caddyhttp" "math/rand" + "net" + "net/http" + "strconv" "strings" "sync" "time" @@ -45,3 +49,33 @@ func randomString(n int) string { return sb.String() } + +func getClientAddress(req *http.Request) (string, int) { + + var ( + clientIp string + clientPort int + ) + + if address, ok := caddyhttp.GetVar(req.Context(), caddyhttp.ClientIPVarKey).(string); ok && len(address) > 0 { + ip, port, _ := net.SplitHostPort(address) + if ip != "" { + clientIp = ip + } else { + clientIp = address + } + clientPort, _ = strconv.Atoi(port) + } else { + idx := strings.LastIndexByte(req.RemoteAddr, ':') + if idx != -1 { + clientIp = req.RemoteAddr[:idx] + clientPort, _ = strconv.Atoi(req.RemoteAddr[idx+1:]) + } else { + clientIp = req.RemoteAddr + clientPort = 0 + } + } + + return clientIp, clientPort + +} diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 0000000..6dad9f3 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,51 @@ +// Copyright 2023 The OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +package coraza + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/caddyserver/caddy/v2/modules/caddyhttp" + "github.com/stretchr/testify/require" +) + +func TestParsegClientAddress(t *testing.T) { + + remoteIp := "127.0.0.1" + remotePort := 9090 + clientIp := "127.0.0.2" + clientPort := 8080 + + req, _ := http.NewRequest("GET", "/", nil) + + req.RemoteAddr = fmt.Sprintf("%v:%v", remoteIp, remotePort) + ip, port := getClientAddress(req) + require.Equal(t, remoteIp, ip) + require.Equal(t, remotePort, port) + + req.RemoteAddr = remoteIp + ip, port = getClientAddress(req) + require.Equal(t, remoteIp, ip) + require.Equal(t, 0, port) + + req = req.WithContext(context.WithValue(req.Context(), caddyhttp.VarsCtxKey, make(map[string]any))) + req.RemoteAddr = fmt.Sprintf("%v:%v", remoteIp, remotePort) + + ip, port = getClientAddress(req) + require.Equal(t, remoteIp, ip) + require.Equal(t, remotePort, port) + + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, fmt.Sprintf("%v:%v", clientIp, clientPort)) + ip, port = getClientAddress(req) + require.Equal(t, clientIp, ip) + require.Equal(t, clientPort, port) + + caddyhttp.SetVar(req.Context(), caddyhttp.ClientIPVarKey, clientIp) + ip, port = getClientAddress(req) + require.Equal(t, clientIp, ip) + require.Equal(t, 0, port) +}