Skip to content

Commit

Permalink
op-txproxy: enable auth checks + jsonrpc responses when auth fails (#82)
Browse files Browse the repository at this point in the history
* re-enable auth

* forward auth errors so we get a valid json-rpc response for them

* typo
  • Loading branch information
hamdiallam authored Nov 7, 2024
1 parent 7d21ec3 commit ed410c0
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 40 deletions.
38 changes: 25 additions & 13 deletions op-txproxy/auth_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package op_txproxy
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"strings"
Expand All @@ -16,6 +17,11 @@ var (
defaultBodyLimit = 5 * 1024 * 1024 // default in op-geth

DefaultAuthHeaderKey = "X-Optimism-Signature"

// errs
misformattedAuthErr = errors.New("misformatted auth header")
invalidAuthSignatureErr = errors.New("invalid auth signature")
mismatchedRecoveredSignerErr = errors.New("mismatched recovered signer")
)

type authHandler struct {
Expand All @@ -24,11 +30,13 @@ type authHandler struct {
}

// This middleware detects when authentication information is present on the request. If
// so, it will validate and set the caller in the request context. It does not reject
// if authentication information is missing. It is up to the request handler to do so via
// the missing `AuthContext`
// - NOTE: only up to the default body limit (5MB) is read when constructing the text hash
// that is signed over by the caller
// so, it will validate and set the caller in the request context. It does not reject any
// requests and leaves it up to the request handler to do so.
// 1. Missing Auth Header: AuthContext is missing from context
// 2. Failed Validation: AuthContext is set with a populated Err
// 3. Passed Validation: AuthContext is set with the authenticated caller
//
// note: only up to the default body limit (5MB) is read when constructing the text hash
func AuthMiddleware(headerKey string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return &authHandler{headerKey, next}
Expand All @@ -39,6 +47,7 @@ type authContextKey struct{}

type AuthContext struct {
Caller common.Address
Err error
}

// ServeHTTP serves JSON-RPC requests over HTTP, implements http.Handler
Expand All @@ -50,16 +59,17 @@ func (h *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
authElems := strings.Split(authHeader, ":")
if len(authElems) != 2 {
http.Error(w, "misformatted auth header", http.StatusBadRequest)
newCtx := context.WithValue(r.Context(), authContextKey{}, &AuthContext{common.Address{}, misformattedAuthErr})
h.next.ServeHTTP(w, r.WithContext(newCtx))
return
}

if r.Body == nil {
// edge case from unit tests
if r.Body == nil { // edge case from unit tests
r.Body = io.NopCloser(bytes.NewBuffer(nil))
}

// Since this middleware runs prior to the server, we need to manually apply the body limit when reading.
// Since this middleware runs prior to the server, we need to manually apply the body limit when
// reading. We reject if we fail to read since this is an issue with this request
bodyBytes, err := io.ReadAll(io.LimitReader(r.Body, int64(defaultBodyLimit)))
if err != nil {
http.Error(w, "unable to parse request body", http.StatusInternalServerError)
Expand All @@ -77,18 +87,20 @@ func (h *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
txtHash := accounts.TextHash(bodyBytes)
caller, signature := common.HexToAddress(authElems[0]), common.FromHex(authElems[1])
sigPubKey, err := crypto.SigToPub(txtHash, signature)
if err != nil {
http.Error(w, "invalid authentication signature", http.StatusBadRequest)
if sigPubKey == nil || err != nil {
newCtx := context.WithValue(r.Context(), authContextKey{}, &AuthContext{common.Address{}, invalidAuthSignatureErr})
h.next.ServeHTTP(w, r.WithContext(newCtx))
return
}

if caller != crypto.PubkeyToAddress(*sigPubKey) {
http.Error(w, "mismatched recovered signer", http.StatusBadRequest)
newCtx := context.WithValue(r.Context(), authContextKey{}, &AuthContext{common.Address{}, mismatchedRecoveredSignerErr})
h.next.ServeHTTP(w, r.WithContext(newCtx))
return
}

// Set the authenticated caller in the context
newCtx := context.WithValue(r.Context(), authContextKey{}, &AuthContext{caller})
newCtx := context.WithValue(r.Context(), authContextKey{}, &AuthContext{caller, nil})
h.next.ServeHTTP(w, r.WithContext(newCtx))
}

Expand Down
80 changes: 54 additions & 26 deletions op-txproxy/auth_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package op_txproxy
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -20,62 +21,75 @@ import (
"github.com/stretchr/testify/require"
)

var pingHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "ping")
})

func TestAuthHandlerMissingAuth(t *testing.T) {
handler := authHandler{next: pingHandler}
var authContext *AuthContext
handler := authHandler{headerKey: "auth", next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authContext = AuthFromContext(r.Context())
})}

rr := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
handler.ServeHTTP(rr, r)

// simply forwards the request
require.Equal(t, http.StatusOK, rr.Code)
require.Equal(t, "ping", rr.Body.String())
handler.ServeHTTP(rr, r)
require.Nil(t, authContext)
}

func TestAuthHandlerBadHeader(t *testing.T) {
handler := authHandler{headerKey: "auth", next: pingHandler}
var authContext *AuthContext
handler := authHandler{headerKey: "auth", next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authContext = AuthFromContext(r.Context())
})}

rr := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
r.Header.Set("auth", "foobarbaz")

handler.ServeHTTP(rr, r)
require.Equal(t, http.StatusBadRequest, rr.Code)
require.NotNil(t, authContext)
require.Zero(t, authContext.Caller)
require.Equal(t, misformattedAuthErr, authContext.Err)
}

func TestAuthHandlerBadSignature(t *testing.T) {
handler := authHandler{headerKey: "auth", next: pingHandler}
var authContext *AuthContext
handler := authHandler{headerKey: "auth", next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authContext = AuthFromContext(r.Context())
})}

rr := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", nil)
r.Header.Set("auth", fmt.Sprintf("%s:%s", common.HexToAddress("0xa"), "foobar"))
r.Header.Set("auth", fmt.Sprintf("%s:%s", common.HexToAddress("a"), "foobar"))

handler.ServeHTTP(rr, r)
require.Equal(t, http.StatusBadRequest, rr.Code)
require.NotNil(t, authContext)
require.Zero(t, authContext.Caller)
require.Equal(t, invalidAuthSignatureErr, authContext.Err)
}

func TestAuthHandlerMismatchedCaller(t *testing.T) {
handler := authHandler{headerKey: "auth", next: pingHandler}
var authContext *AuthContext
handler := authHandler{headerKey: "auth", next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authContext = AuthFromContext(r.Context())
})}

rr := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/", strings.NewReader("body"))
body := bytes.NewBufferString("body")
r, _ := http.NewRequest("GET", "/", body)

privKey, _ := crypto.GenerateKey()
sig, _ := crypto.Sign(accounts.TextHash([]byte("body")), privKey)
r.Header.Set("auth", fmt.Sprintf("%s:%s", common.HexToAddress("0xa"), sig))
sig, _ := crypto.Sign(accounts.TextHash(body.Bytes()), privKey)
r.Header.Set("auth", fmt.Sprintf("%s:%s", common.HexToAddress("a"), common.Bytes2Hex(sig)))

handler.ServeHTTP(rr, r)
require.Equal(t, http.StatusBadRequest, rr.Code)
require.NotNil(t, authContext)
require.Zero(t, authContext.Caller)
require.Equal(t, mismatchedRecoveredSignerErr, authContext.Err)
}

func TestAuthHandlerSetContext(t *testing.T) {
var ctx *AuthContext
var authContext *AuthContext
ctxHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx = AuthFromContext(r.Context())
authContext = AuthFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})

Expand All @@ -92,13 +106,23 @@ func TestAuthHandlerSetContext(t *testing.T) {

handler.ServeHTTP(rr, r)
require.Equal(t, http.StatusOK, rr.Code)
require.Nil(t, authContext.Err)
require.Equal(t, addr, authContext.Caller)
}

type AuthAwareRPC struct{}

require.NotNil(t, ctx)
require.Equal(t, addr, ctx.Caller)
func (a *AuthAwareRPC) Run(ctx context.Context) error {
authContext := AuthFromContext(ctx)
if authContext == nil || authContext.Err != nil {
return errors.New("failed")
}
return nil
}

func TestAuthHandlerRpcMiddleware(t *testing.T) {
rpcServer := oprpc.NewServer("127.0.0.1", 0, "", oprpc.WithMiddleware(AuthMiddleware("auth")))
apis := []rpc.API{{Namespace: "test", Service: &AuthAwareRPC{}}}
rpcServer := oprpc.NewServer("127.0.0.1", 0, "", oprpc.WithAPIs(apis), oprpc.WithMiddleware(AuthMiddleware("auth")))
require.NoError(t, rpcServer.Start())
t.Cleanup(func() { _ = rpcServer.Stop() })

Expand All @@ -107,13 +131,17 @@ func TestAuthHandlerRpcMiddleware(t *testing.T) {
require.NoError(t, err)
defer clnt.Close()

// pass without auth (default handler does not deny)
// passthrough auth (default handler does not deny)
err = clnt.CallContext(context.Background(), nil, "rpc_modules")
require.Nil(t, err)

// denied with no header
err = clnt.CallContext(context.Background(), nil, "test_run")
require.NotNil(t, err)

// denied with bad auth header
clnt.SetHeader("auth", "foobar")
err = clnt.CallContext(context.Background(), nil, "rpc_modules")
err = clnt.CallContext(context.Background(), nil, "test_run")
require.NotNil(t, err)
}

Expand Down
7 changes: 6 additions & 1 deletion op-txproxy/conditional_txs.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ var (
failedValidationErr = &rpc.JsonError{Message: "failed conditional validation", Code: params.TransactionConditionalRejectedErrCode}
maxCostExceededErr = &rpc.JsonError{Message: "max cost exceeded", Code: params.TransactionConditionalRejectedErrCode}
missingAuthenticationErr = &rpc.JsonError{Message: "missing authentication", Code: params.TransactionConditionalRejectedErrCode}
invalidAuthenticationErr = &rpc.JsonError{Message: "invalid authentication", Code: params.TransactionConditionalRejectedErrCode}
)

type ConditionalTxService struct {
Expand Down Expand Up @@ -94,6 +95,10 @@ func (s *ConditionalTxService) SendRawTransactionConditional(ctx context.Context
s.failures.WithLabelValues("missing auth").Inc()
return common.Hash{}, missingAuthenticationErr
}
if authInfo.Err != nil {
s.failures.WithLabelValues("invalid auth").Inc()
return common.Hash{}, invalidAuthenticationErr
}

// Handle the request. For now, we do nothing with the authenticated signer
hash, err := s.sendCondTx(ctx, authInfo.Caller, txBytes, &cond)
Expand Down Expand Up @@ -122,7 +127,7 @@ func (s *ConditionalTxService) sendCondTx(ctx context.Context, caller common.Add
return txHash, failedValidationErr
}
if cost > params.TransactionConditionalMaxCost {
s.log.Info("max cost exceeded", "cost", cost, "max", params.TransactionConditionalMaxCost, "caller", caller.String())
s.log.Info("conditional max cost exceeded", "cost", cost, "max", params.TransactionConditionalMaxCost, "caller", caller.String())
return txHash, maxCostExceededErr
}

Expand Down

0 comments on commit ed410c0

Please sign in to comment.