diff --git a/internal/deploy/callback_addon.go b/internal/deploy/callback_addon.go index db4ad95..0348a1f 100644 --- a/internal/deploy/callback_addon.go +++ b/internal/deploy/callback_addon.go @@ -7,14 +7,14 @@ import ( "net" "net/http" "strings" - "testing" + "sync" + "sync/atomic" "time" "github.com/matrix-org/complement/ct" - "github.com/matrix-org/complement/must" ) -var lastTestName string +var lastTestName atomic.Value = atomic.Value{} type CallbackData struct { Method string `json:"method"` @@ -36,16 +36,41 @@ func (cd CallbackData) String() string { return fmt.Sprintf("%s %s (token=%s) req_len=%d => HTTP %v", cd.Method, cd.URL, cd.AccessToken, len(cd.RequestBody), cd.ResponseCode) } -// NewCallbackServer runs a local HTTP server that can read callbacks from mitmproxy. -// Returns the URL of the callback server for use with WithMITMOptions, along with a close function -// which should be called when the test finishes to shut down the HTTP server. -func NewCallbackServer(t *testing.T, hostnameRunningComplement string, cb func(CallbackData) *CallbackResponse) (callbackURL string, close func()) { - if lastTestName != "" { - t.Logf("WARNING[%s]: NewCallbackServer called without closing the last one. Check test '%s'", t.Name(), lastTestName) - } - lastTestName = t.Name() - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { +const ( + requestPath = "/request" + responsePath = "/response" +) + +type CallbackServer struct { + srv *http.Server + mux *http.ServeMux + baseURL string + + mu *sync.Mutex + onRequest http.HandlerFunc + onResponse http.HandlerFunc +} + +func (s *CallbackServer) SetOnRequestCallback(t ct.TestLike, cb func(CallbackData) *CallbackResponse) (callbackURL string) { + s.mu.Lock() + defer s.mu.Unlock() + s.onRequest = s.createHandler(t, cb) + return s.baseURL + requestPath +} +func (s *CallbackServer) SetOnResponseCallback(t ct.TestLike, cb func(CallbackData) *CallbackResponse) (callbackURL string) { + s.mu.Lock() + defer s.mu.Unlock() + s.onResponse = s.createHandler(t, cb) + return s.baseURL + responsePath +} + +// Shut down the server. +func (s *CallbackServer) Close() { + s.srv.Close() + lastTestName.Store("") +} +func (s *CallbackServer) createHandler(t ct.TestLike, cb func(CallbackData) *CallbackResponse) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { var data CallbackData if err := json.NewDecoder(r.Body).Decode(&data); err != nil { ct.Errorf(t, "error decoding json: %s", err) @@ -74,18 +99,60 @@ func NewCallbackServer(t *testing.T, hostnameRunningComplement string, cb func(C } fmt.Println(string(cbResBytes)) w.Write(cbResBytes) - }) + } +} + +// NewCallbackServer runs a local HTTP server that can read callbacks from mitmproxy. +// Automatically listens on a high numbered port. Must be Close()d at the end of the test. +// Register callback handlers via CallbackServer.SetOnRequestCallback and CallbackServer.SetOnResponseCallback +func NewCallbackServer(t ct.TestLike, hostnameRunningComplement string) (*CallbackServer, error) { + last := lastTestName.Load() + if last != nil && last.(string) != "" { + t.Logf("WARNING[%s]: NewCallbackServer called without closing the last one. Check test '%s'", t.Name(), last.(string)) + } + lastTestName.Store(t.Name()) + mux := http.NewServeMux() + // listen on a random high numbered port ln, err := net.Listen("tcp", ":0") //nolint - must.NotError(t, "failed to listen on a tcp port", err) + if err != nil { + return nil, fmt.Errorf("failed to listen on a tcp port: %s", err) + } port := ln.Addr().(*net.TCPAddr).Port - srv := http.Server{ + srv := &http.Server{ Addr: fmt.Sprintf(":%d", port), Handler: mux, } go srv.Serve(ln) - return fmt.Sprintf("http://%s:%d", hostnameRunningComplement, port), func() { - srv.Close() - lastTestName = "" + + callbackServer := &CallbackServer{ + mux: mux, + srv: srv, + mu: &sync.Mutex{}, + baseURL: fmt.Sprintf("http://%s:%d", hostnameRunningComplement, port), } + mux.HandleFunc(requestPath, func(w http.ResponseWriter, r *http.Request) { + callbackServer.mu.Lock() + h := callbackServer.onRequest + callbackServer.mu.Unlock() + if h == nil { + w.WriteHeader(404) + w.Write([]byte(`{"error":"no request handler registered"}`)) + return + } + h(w, r) + }) + mux.HandleFunc(responsePath, func(w http.ResponseWriter, r *http.Request) { + callbackServer.mu.Lock() + h := callbackServer.onResponse + callbackServer.mu.Unlock() + if h == nil { + w.WriteHeader(404) + w.Write([]byte(`{"error":"no response handler registered"}`)) + return + } + h(w, r) + }) + + return callbackServer, nil } diff --git a/internal/deploy/callback_addon_test.go b/internal/deploy/callback_addon_test.go index d40214b..0178bc3 100644 --- a/internal/deploy/callback_addon_test.go +++ b/internal/deploy/callback_addon_test.go @@ -38,9 +38,10 @@ func TestCallbackAddon(t *testing.T) { }) testCases := []struct { - name string - filter string - inner func(t *testing.T, checker *checker) + name string + filter string + needsRequestCallback bool + inner func(t *testing.T, checker *checker) }{ { name: "works", @@ -233,8 +234,32 @@ func TestCallbackAddon(t *testing.T) { must.Equal(t, gjson.ParseBytes(body).Get("foo").Str, "bar", "response body was not modified") }, }, - // TODO: can block requests - // TODO: migrate functionality from status_code addon + { + name: "can block requests and modify response codes and bodies", + filter: "~m PUT", + needsRequestCallback: true, + inner: func(t *testing.T, checker *checker) { + checker.expect(&callbackRequest{ + OnRequestCallback: func(cd CallbackData) *CallbackResponse { + return &CallbackResponse{ + RespondStatusCode: 200, + RespondBody: json.RawMessage(`{"yep": "ok"}`), + } + }, + }) + // This is a PUT so will be intercepted + res := client.MustSetGlobalAccountData(t, "this_wont_go_through", map[string]any{"foo": "bar"}) + checker.wait() + must.Equal(t, res.StatusCode, 200, "response code was not set") + body, err := io.ReadAll(res.Body) + must.NotError(t, "failed to read CSAPI response", err) + must.Equal(t, gjson.ParseBytes(body).Get("yep").Str, "ok", "response body was not set") + + // now check it didn't go through by doing a GET which isn't intercepted + res = client.GetGlobalAccountData(t, "this_wont_go_through") + must.Equal(t, res.StatusCode, 404, "GET returned data when the PUT should have been intercepted") + }, + }, } for _, tc := range testCases { @@ -244,25 +269,34 @@ func TestCallbackAddon(t *testing.T) { ch: make(chan callbackRequest, 3), mu: &sync.Mutex{}, } - callbackURL, close := NewCallbackServer( + cbServer, err := NewCallbackServer( t, deployment.GetConfig().HostnameRunningComplement, - func(cd CallbackData) *CallbackResponse { - return checker.onCallback(cd) - }, ) - defer close() - mitmClient := deployment.MITM() - mitmOpts := map[string]any{ - "callback": map[string]any{ - "callback_url": callbackURL, - }, + callbackURL := cbServer.SetOnResponseCallback(t, func(cd CallbackData) *CallbackResponse { + return checker.onResponseCallback(cd) + }) + var reqCallbackURL string + if tc.needsRequestCallback { + reqCallbackURL = cbServer.SetOnRequestCallback(t, func(cd CallbackData) *CallbackResponse { + return checker.onRequestCallback(cd) + }) + } + must.NotError(t, "failed to create callback server", err) + defer cbServer.Close() + callbackOpts := map[string]any{ + "callback_response_url": callbackURL, } if tc.filter != "" { - cb := mitmOpts["callback"].(map[string]any) - cb["filter"] = tc.filter - mitmOpts["callback"] = cb + callbackOpts["filter"] = tc.filter + } + if reqCallbackURL != "" { + callbackOpts["callback_request_url"] = reqCallbackURL } - lockID := mitmClient.lockOptions(t, mitmOpts) + + mitmClient := deployment.MITM() + lockID := mitmClient.lockOptions(t, map[string]any{ + "callback": callbackOpts, + }) tc.inner(t, checker) mitmClient.unlockOptions(t, lockID) }) @@ -270,11 +304,12 @@ func TestCallbackAddon(t *testing.T) { } type callbackRequest struct { - Method string - PathContains string - AccessToken string - ResponseCode int - OnCallback func(cd CallbackData) *CallbackResponse + Method string + PathContains string + AccessToken string + ResponseCode int + OnRequestCallback func(cd CallbackData) *CallbackResponse + OnCallback func(cd CallbackData) *CallbackResponse } type checker struct { @@ -285,7 +320,7 @@ type checker struct { noCallbacks bool } -func (c *checker) onCallback(cd CallbackData) *CallbackResponse { +func (c *checker) onResponseCallback(cd CallbackData) *CallbackResponse { c.mu.Lock() if c.noCallbacks { ct.Errorf(c.t, "wanted no callbacks but got %+v", cd) @@ -322,6 +357,16 @@ func (c *checker) onCallback(cd CallbackData) *CallbackResponse { return callbackResponse } +func (c *checker) onRequestCallback(cd CallbackData) *CallbackResponse { + c.mu.Lock() + cb := c.want.OnRequestCallback + c.mu.Unlock() + if cb != nil { + return cb(cd) + } + return nil +} + func (c *checker) expect(want *callbackRequest) { c.mu.Lock() defer c.mu.Unlock() @@ -340,7 +385,7 @@ func (c *checker) wait() { case got := <-c.ch: // we can't sanity check if there are callbacks involved, as we can't easily // pair responses up. - if c.want.OnCallback == nil && !reflect.DeepEqual(got, *c.want) { + if c.want.OnCallback == nil && c.want.OnRequestCallback == nil && !reflect.DeepEqual(got, *c.want) { ct.Fatalf(c.t, "checker: got success from a different request: did you forget to wait?"+ " Received %+v but expected +%v", got, c.want) } diff --git a/internal/deploy/mitm.go b/internal/deploy/mitm.go index 90e0bad..971d187 100644 --- a/internal/deploy/mitm.go +++ b/internal/deploy/mitm.go @@ -122,12 +122,14 @@ func (c *MITMConfiguration) Execute(inner func()) { c.mu.Lock() for _, pathConfig := range c.pathCfgs { if pathConfig.listener != nil { - callbackURL, closeCallbackServer := NewCallbackServer(c.t, c.client.hostnameRunningComplement, pathConfig.listener) - defer closeCallbackServer() + cbServer, err := NewCallbackServer(c.t, c.client.hostnameRunningComplement) + must.NotError(c.t, "failed to start callback server", err) + callbackURL := cbServer.SetOnResponseCallback(c.t, pathConfig.listener) + defer cbServer.Close() body["callback"] = map[string]any{ - "callback_url": callbackURL, - "filter": pathConfig.filter(), + "callback_response_url": callbackURL, + "filter": pathConfig.filter(), } } if pathConfig.blockRequest != nil { diff --git a/tests/mitmproxy_addons/callback.py b/tests/mitmproxy_addons/callback.py index 29e49cf..f553a7c 100644 --- a/tests/mitmproxy_addons/callback.py +++ b/tests/mitmproxy_addons/callback.py @@ -10,7 +10,7 @@ from urllib.error import HTTPError, URLError from datetime import datetime -# Callback will intercept a response and send a POST request to the provided callback_url, with +# Callback will intercept a request and/or response and send a POST request to the provided url, with # the following JSON object. Supports filters: https://docs.mitmproxy.org/stable/concepts-filters/ # { # method: "GET|PUT|...", @@ -20,10 +20,25 @@ # response_body: { some json object }, # response_code: 200, # } -# Currently this is a read-only callback. The response cannot be modified, but side-effects can be -# taken. For example, tests may wish to terminate a client prior to the delivery of a response but -# after the server has processed the request, or the test may wish to use the response as a -# synchronisation point for a Waiter. +# The response to this request can control what gets returned to the client. The response object: +# { +# "respond_status_code": 200, +# "respond_body": { "some": "json_object" } +# } +# If {} is sent back, the response is not modified. Likewise, if `respond_body` is set but +# `respond_status_code` is not, only the response body is modified, not the status code, and vice versa. +# +# To use this addon, configure it with these fields: +# - callback_request_url: the URL to send outbound requests to. This allows callbacks to intercept +# requests BEFORE they reach the server. The request/response struct in this +# callback is the same as `callback_response_url`, except `response_body` +# and `response_code` will be missing as the request hasn't been processed +# yet. To block the request from reaching the server, the callback needs to +# provide both `respond_status_code` and `respond_body`. +# - callback_response_url: the URL to send inbound responses to. This allows callbacks to modify +# response content. +# - filter: the mitmproxy filter to apply. If unset, ALL requests are eligible to go to the callback +# server. class Callback: def __init__(self): self.reset() @@ -32,7 +47,8 @@ def __init__(self): def reset(self): self.config = { - "callback_url": "", + "callback_request_url": "", + "callback_response_url": "", "filter": None, } @@ -40,7 +56,11 @@ def load(self, loader): loader.add_option( name="callback", typespec=dict, - default={"callback_url": "", "filter": None}, + default={ + "callback_request_url": "", + "callback_response_url": "", + "filter": None, + }, help="Change the callback url, with an optional filter", ) @@ -48,22 +68,43 @@ def configure(self, updates): if "callback" not in updates: self.reset() return - if ctx.options.callback is None or ctx.options.callback["callback_url"] == "": + if ctx.options.callback is None: self.reset() return self.config = ctx.options.callback new_filter = self.config.get('filter', None) - print(f"callback will hit {self.config['callback_url']} filter={new_filter}") + print(f"callback req_url={self.config.get('callback_request_url')} res_url={self.config.get('callback_response_url')} filter={new_filter}") if new_filter: self.filter = flowfilter.parse(new_filter) else: self.filter = self.matchall + async def request(self, flow): + # always ignore the controller + if flow.request.pretty_host == MITM_DOMAIN_NAME: + return + if self.config.get("callback_request_url", "") == "": + return # ignore requests if we aren't told a url + if not flowfilter.match(self.filter, flow): + return # ignore requests which don't match the filter + try: # e.g GET requests have no req body + req_body = flow.request.json() + except: + req_body = None + print(f'{datetime.now().strftime("%H:%M:%S.%f")} hitting request callback for {flow.request.url}') + callback_body = { + "method": flow.request.method, + "access_token": flow.request.headers.get("Authorization", "").removeprefix("Bearer "), + "url": flow.request.url, + "request_body": req_body, + } + await self.send_callback(flow, self.config["callback_request_url"], callback_body) + async def response(self, flow): # always ignore the controller if flow.request.pretty_host == MITM_DOMAIN_NAME: return - if self.config["callback_url"] == "": + if self.config.get("callback_response_url","") == "": return # ignore responses if we aren't told a url if flowfilter.match(self.filter, flow): try: # e.g GET requests have no req body @@ -74,7 +115,7 @@ async def response(self, flow): res_body = flow.response.json() except: res_body = None - print(f'{datetime.now().strftime("%H:%M:%S.%f")} hitting callback for {flow.request.url}') + print(f'{datetime.now().strftime("%H:%M:%S.%f")} hitting response callback for {flow.request.url}') callback_body = { "method": flow.request.method, "access_token": flow.request.headers.get("Authorization", "").removeprefix("Bearer "), @@ -83,25 +124,37 @@ async def response(self, flow): "request_body": req_body, "response_body": res_body, } - try: - # use asyncio so we don't block other unrelated requests from being processed - async with aiohttp.request( - method="POST",url=self.config["callback_url"], timeout=aiohttp.ClientTimeout(total=10), - headers={"Content-Type": "application/json"}, - json=callback_body) as response: - print(f'{datetime.now().strftime("%H:%M:%S.%f")} callback for {flow.request.url} returned HTTP {response.status}') - test_response_body = await response.json() - # if the response includes some keys then we are modifying the response on a per-key basis. - if len(test_response_body) > 0: - respond_status_code = test_response_body.get("respond_status_code", flow.response.status_code) - respond_body = test_response_body.get("respond_body", res_body) - flow.response = Response.make( - respond_status_code, json.dumps(respond_body), - headers={ - "MITM-Proxy": "yes", # so we don't reprocess this - "Content-Type": "application/json", - }) + await self.send_callback(flow, self.config["callback_response_url"], callback_body) - except Exception as error: - print(f"ERR: callback for {flow.request.url} returned {error}") - print(f"ERR: callback, provided request body was {callback_body}") + async def send_callback(self, flow, url: str, body: dict): + try: + # use asyncio so we don't block other unrelated requests from being processed + async with aiohttp.request( + method="POST", + url=url, + timeout=aiohttp.ClientTimeout(total=10), + headers={"Content-Type": "application/json"}, + json=body, + ) as response: + print(f'{datetime.now().strftime("%H:%M:%S.%f")} callback for {flow.request.url} returned HTTP {response.status}') + if response.content_type != 'application/json': + err_response_body = await response.text() + print(f'ERR: callback server returned non-json: {err_response_body}') + raise Exception("callback server content-type: " + response.content_type) + test_response_body = await response.json() + # if the response includes some keys then we are modifying the response on a per-key basis. + if len(test_response_body) > 0: + # use what fields were provided preferentially. + # For requests: both fields must be provided so the default case won't execute. + # For responses: fields are optional but the default case is always specified. + respond_status_code = test_response_body.get("respond_status_code", body.get("response_code")) + respond_body = test_response_body.get("respond_body", body.get("response_body")) + flow.response = Response.make( + respond_status_code, json.dumps(respond_body), + headers={ + "MITM-Proxy": "yes", # so we don't reprocess this + "Content-Type": "application/json", + }) + except Exception as error: + print(f"ERR: callback for {flow.request.url} returned {error}") + print(f"ERR: callback, provided request body was {body}")