From 6c9a5626bd70567c6293b6045a6bfd648ae0a95d Mon Sep 17 00:00:00 2001 From: Toan Nguyen Date: Tue, 11 Jun 2024 00:38:36 +0700 Subject: [PATCH] fix: avoid modifying mutable request endpoints (#14) --- rest/connector_test.go | 32 +++++++++++++++++++++----------- rest/internal/request.go | 10 ++++------ rest/mutation.go | 3 +-- rest/query.go | 3 +-- rest/request.go | 4 ++-- 5 files changed, 29 insertions(+), 23 deletions(-) diff --git a/rest/connector_test.go b/rest/connector_test.go index 5cd3ea1..4e5bba1 100644 --- a/rest/connector_test.go +++ b/rest/connector_test.go @@ -167,19 +167,26 @@ func TestRESTConnector_authentication(t *testing.T) { } } }, - "arguments": {}, + "arguments": { + "status": { + "type": "literal", + "value": "available" + } + }, "collection_relationships": {} }`) - res, err := http.Post(fmt.Sprintf("%s/query", testServer.URL), "application/json", bytes.NewBuffer(reqBody)) - assert.NilError(t, err) - assertHTTPResponse(t, res, http.StatusOK, schema.QueryResponse{ - { - Rows: []map[string]any{ - {"__value": map[string]any{}}, + for i := 0; i < 2; i++ { + res, err := http.Post(fmt.Sprintf("%s/query", testServer.URL), "application/json", bytes.NewBuffer(reqBody)) + assert.NilError(t, err) + assertHTTPResponse(t, res, http.StatusOK, schema.QueryResponse{ + { + Rows: []map[string]any{ + {"__value": map[string]any{}}, + }, }, - }, - }) + }) + } }) t.Run("retry", func(t *testing.T) { @@ -513,8 +520,11 @@ func createMockServer(t *testing.T, apiKey string, bearerToken string) *httptest switch r.Method { case http.MethodGet: if r.Header.Get("Authorization") != fmt.Sprintf("Bearer %s", bearerToken) { - t.Errorf("invalid bearer token, expected %s, got %s", bearerToken, r.Header.Get("Authorization")) - t.FailNow() + t.Fatalf("invalid bearer token, expected %s, got %s", bearerToken, r.Header.Get("Authorization")) + return + } + if r.URL.Query().Encode() != "status=available" { + t.Fatalf("expected query param: status=available, got: %s", r.URL.Query().Encode()) return } writeResponse(w) diff --git a/rest/internal/request.go b/rest/internal/request.go index bbdb21a..daed1a0 100644 --- a/rest/internal/request.go +++ b/rest/internal/request.go @@ -81,13 +81,13 @@ func getHostFromServers(servers []rest.ServerConfig, serverIDs []string) (string } func buildDistributedRequestsWithOptions(request *RetryableRequest, restOptions *RESTOptions) ([]RetryableRequest, error) { - if strings.HasPrefix(request.RawRequest.URL, "http") { + if strings.HasPrefix(request.URL, "http") { return []RetryableRequest{*request}, nil } if !restOptions.Distributed || len(restOptions.Settings.Servers) == 1 { host, serverID := getHostFromServers(restOptions.Settings.Servers, restOptions.Servers) - request.URL = fmt.Sprintf("%s%s", host, request.RawRequest.URL) + request.URL = fmt.Sprintf("%s%s", host, request.URL) request.ServerID = serverID if err := request.applySettings(restOptions.Settings); err != nil { return nil, err @@ -118,7 +118,7 @@ func buildDistributedRequestsWithOptions(request *RetryableRequest, restOptions } req := RetryableRequest{ - URL: fmt.Sprintf("%s%s", host, request.RawRequest.URL), + URL: fmt.Sprintf("%s%s", host, request.URL), ServerID: serverID, RawRequest: request.RawRequest, ContentType: request.ContentType, @@ -217,7 +217,7 @@ func (req *RetryableRequest) applySecurity(serverConfig *rest.ServerConfig) erro case rest.APIKeyInQuery: value := securityScheme.Value.Value() if value != nil { - endpoint, err := url.Parse(req.RawRequest.URL) + endpoint, err := url.Parse(req.URL) if err != nil { return err } @@ -226,8 +226,6 @@ func (req *RetryableRequest) applySecurity(serverConfig *rest.ServerConfig) erro q.Add(securityScheme.Name, *securityScheme.Value.Value()) endpoint.RawQuery = q.Encode() req.URL = endpoint.String() - } else { - req.URL = req.RawRequest.URL } case rest.APIKeyInCookie: if securityScheme.Value != nil { diff --git a/rest/mutation.go b/rest/mutation.go index 8e35ffe..dfdb747 100644 --- a/rest/mutation.go +++ b/rest/mutation.go @@ -62,9 +62,8 @@ func (c *RESTConnector) execProcedure(ctx context.Context, operation *schema.Mut // 2. create and execute request // 3. evaluate response selection - procedure.Request.URL = endpoint restOptions.Settings = settings - httpRequest, err := c.createRequest(procedure.Request, headers, rawArgs) + httpRequest, err := c.createRequest(procedure.Request, endpoint, headers, rawArgs) if err != nil { return nil, err } diff --git a/rest/query.go b/rest/query.go index f0c5c16..9dc9898 100644 --- a/rest/query.go +++ b/rest/query.go @@ -69,9 +69,8 @@ func (c *RESTConnector) execQuery(ctx context.Context, request *schema.QueryRequ // 2. create and execute request // 3. evaluate response selection - function.Request.URL = endpoint restOptions.Settings = settings - httpRequest, err := c.createRequest(function.Request, headers, nil) + httpRequest, err := c.createRequest(function.Request, endpoint, headers, nil) if err != nil { return nil, err } diff --git a/rest/request.go b/rest/request.go index 13b4382..9961966 100644 --- a/rest/request.go +++ b/rest/request.go @@ -16,7 +16,7 @@ import ( "github.com/hasura/ndc-sdk-go/utils" ) -func (c *RESTConnector) createRequest(rawRequest *rest.Request, headers http.Header, arguments map[string]any) (*internal.RetryableRequest, error) { +func (c *RESTConnector) createRequest(rawRequest *rest.Request, endpoint string, headers http.Header, arguments map[string]any) (*internal.RetryableRequest, error) { var buffer io.ReadSeeker contentType := contentTypeJSON bodyData, ok := arguments["body"] @@ -66,7 +66,7 @@ func (c *RESTConnector) createRequest(rawRequest *rest.Request, headers http.Hea } request := &internal.RetryableRequest{ - URL: rawRequest.URL, + URL: endpoint, RawRequest: rawRequest, ContentType: contentType, Headers: headers,