Skip to content

Commit

Permalink
Add protocol helpers to infer procedure and type
Browse files Browse the repository at this point in the history
Two new methods are added to allow for inferring the procedure and
protocol type of a request. These are provided to be used with http
middleware to deduce information about the requests. For example,
authentication middleware may wish to block on certain protocols or
to conditionally allow routes.

Signed-off-by: Edward McFarlane <[email protected]>
  • Loading branch information
emcfarlane committed Jun 25, 2024
1 parent 193652d commit 23627e1
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 28 deletions.
60 changes: 32 additions & 28 deletions error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,34 +61,7 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {
}

func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType {
ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
isPost := request.Method == http.MethodPost
isGet := request.Method == http.MethodGet
switch {
case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)):
return grpcProtocol
case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)):
return grpcWebProtocol
case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix):
// Streaming ignores the requireConnectProtocolHeader option as the
// Content-Type is enough to determine the protocol.
if err := connectCheckProtocolVersion(request, false /* required */); err != nil {
return unknownProtocol
}
return connectStreamProtocol
case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix):
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
case isGet:
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
default:
return unknownProtocol
}
return classifyRequest(request, w.requireConnectProtocolHeader)
}

// IsSupported checks whether a request is using one of the ErrorWriter's
Expand Down Expand Up @@ -177,3 +150,34 @@ func (w *ErrorWriter) writeGRPCWeb(response http.ResponseWriter, err error) erro
response.WriteHeader(http.StatusOK)
return nil
}

func classifyRequest(request *http.Request, requireConnectProtocolHeader bool) protocolType {
ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
isPost := request.Method == http.MethodPost
isGet := request.Method == http.MethodGet
switch {
case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)):
return grpcProtocol
case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)):
return grpcWebProtocol
case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix):
// Streaming ignores the requireConnectProtocolHeader option as the
// Content-Type is enough to determine the protocol.
if err := connectCheckProtocolVersion(request, false /* required */); err != nil {
return unknownProtocol
}
return connectStreamProtocol
case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix):
if err := connectCheckProtocolVersion(request, requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
case isGet:
if err := connectCheckProtocolVersion(request, requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
default:
return unknownProtocol
}
}
38 changes: 38 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,44 @@ const (

var errNoTimeout = errors.New("no timeout")

// ProtocolFromRequest returns the protocol name for an HTTP request. It
// inspects the request's method and headers to determine the protocol. If the
// request doesn't match any known protocol, an empty string is returned.
func ProtocolFromRequest(request *http.Request) string {
switch classifyRequest(request, false) {
case connectUnaryProtocol, connectStreamProtocol:
return ProtocolConnect
case grpcProtocol:
return ProtocolGRPC
case grpcWebProtocol:
return ProtocolGRPCWeb
case unknownProtocol:
return ""
default:
return ""
}
}

// ProcedureFromURL returns the procedure name for a URL. It's returned in the
// form "/service/method". If the path doesn't contain a service and method, the
// entire path is returned.
func ProcedureFromURL(url *url.URL) string {
path := strings.TrimSuffix(url.Path, "/")
ultimate := strings.LastIndex(path, "/")
if ultimate < 0 {
return url.Path
}
penultimate := strings.LastIndex(path[:ultimate], "/")
if penultimate < 0 {
return url.Path
}
procedure := path[penultimate:]
if len(procedure) < 4 { // two slashes + service + method
return url.Path
}
return procedure
}

// A Protocol defines the HTTP semantics to use when sending and receiving
// messages. It ties together codecs, compressors, and net/http to produce
// Senders and Receivers.
Expand Down
33 changes: 33 additions & 0 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package connect

import (
"net/url"
"testing"

"connectrpc.com/connect/internal/assert"
Expand Down Expand Up @@ -63,3 +64,35 @@ func BenchmarkCanonicalizeContentType(b *testing.B) {
b.ReportAllocs()
})
}

func TestProcedureFromURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
url string
want string
}{
{name: "simple", url: "http://localhost:8080/foo", want: "/foo"},
{name: "service", url: "http://localhost:8080/service/bar", want: "/service/bar"},
{name: "trailing", url: "http://localhost:8080/service/bar/", want: "/service/bar"},
{name: "subroute", url: "http://localhost:8080/api/service/bar/", want: "/service/bar"},
{name: "subrouteTrailing", url: "http://localhost:8080/api/service/bar/", want: "/service/bar"},
{
name: "real",
url: "http://localhost:8080/connect.ping.v1.PingService/Ping",
want: "/connect.ping.v1.PingService/Ping",
},
}
for _, testcase := range tests {
testcase := testcase
t.Run(testcase.name, func(t *testing.T) {
t.Parallel()
url, err := url.Parse(testcase.url)
if !assert.Nil(t, err) {
return
}
t.Log(url.String())
assert.Equal(t, ProcedureFromURL(url), testcase.want)
})
}
}

0 comments on commit 23627e1

Please sign in to comment.