Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add protocol helpers to infer procedure and type #756

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 32 additions & 28 deletions error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,34 +61,7 @@ func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {
}

func (w *ErrorWriter) classifyRequest(request *http.Request) protocolType {
ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
isPost := request.Method == http.MethodPost
isGet := request.Method == http.MethodGet
switch {
case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)):
return grpcProtocol
case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)):
return grpcWebProtocol
case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix):
// Streaming ignores the requireConnectProtocolHeader option as the
// Content-Type is enough to determine the protocol.
if err := connectCheckProtocolVersion(request, false /* required */); err != nil {
return unknownProtocol
}
return connectStreamProtocol
case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix):
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
case isGet:
if err := connectCheckProtocolVersion(request, w.requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
default:
return unknownProtocol
}
return classifyRequest(request, w.requireConnectProtocolHeader)
}

// IsSupported checks whether a request is using one of the ErrorWriter's
Expand Down Expand Up @@ -177,3 +150,34 @@ func (w *ErrorWriter) writeGRPCWeb(response http.ResponseWriter, err error) erro
response.WriteHeader(http.StatusOK)
return nil
}

func classifyRequest(request *http.Request, requireConnectProtocolHeader bool) protocolType {
ctype := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
isPost := request.Method == http.MethodPost
isGet := request.Method == http.MethodGet
switch {
case isPost && (ctype == grpcContentTypeDefault || strings.HasPrefix(ctype, grpcContentTypePrefix)):
return grpcProtocol
case isPost && (ctype == grpcWebContentTypeDefault || strings.HasPrefix(ctype, grpcWebContentTypePrefix)):
return grpcWebProtocol
case isPost && strings.HasPrefix(ctype, connectStreamingContentTypePrefix):
// Streaming ignores the requireConnectProtocolHeader option as the
// Content-Type is enough to determine the protocol.
if err := connectCheckProtocolVersion(request, false /* required */); err != nil {
return unknownProtocol
}
return connectStreamProtocol
case isPost && strings.HasPrefix(ctype, connectUnaryContentTypePrefix):
if err := connectCheckProtocolVersion(request, requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
case isGet:
if err := connectCheckProtocolVersion(request, requireConnectProtocolHeader); err != nil {
return unknownProtocol
}
return connectUnaryProtocol
default:
return unknownProtocol
}
}
39 changes: 39 additions & 0 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,45 @@ const (

var errNoTimeout = errors.New("no timeout")

// InferProtocolFromRequest returns the inferred protocol name for parsing an
// HTTP request. It inspects the request's method and headers to determine the
// protocol. If the request doesn't match any known protocol, an empty string
// is returned.
func InferProtocolFromRequest(request *http.Request) string {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • The rest of the API uses For and not From
  • would this be better named "ProtocolForRequest"?
  • would it be better to return string, bool?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The For is used to distinguish streams that are to be consumed by the client rather than the handler. So I think the From is better here even tho its not currently in the API.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my mind they are not necessarily synonyms where you'd always use one over another for consistency.
In my mind, the distinction is that For in a function name is used when the function constructs a new value that corresponds to its input or even wraps its input; From is used when the function extracts a value out of the inputs. In that case, From seems appropriate to me.

As far as returning string vs. (string, bool): another option is to add a ProtocolUnknown constant, so that callers can use a switch and don't need to examine two values. If that approach sounds good, the constant value might need to be the empty string so that an uninitialized Peer object reports ProtocolUnknown. WDYT?

switch classifyRequest(request, false) {
case connectUnaryProtocol, connectStreamProtocol:
return ProtocolConnect
case grpcProtocol:
return ProtocolGRPC
case grpcWebProtocol:
return ProtocolGRPCWeb
case unknownProtocol:
return ""
default:
return ""
}
}

// InferProcedureFromURL returns the inferred procedure name from a URL. It's
// returned in the form "/service/method" if a valid suffix is found. If the
// path doesn't contain a service and method, the entire path is returned.
func InferProcedureFromURL(url *url.URL) string {
path := strings.TrimSuffix(url.Path, "/")
ultimate := strings.LastIndex(path, "/")
if ultimate < 0 {
return url.Path
}
penultimate := strings.LastIndex(path[:ultimate], "/")
if penultimate < 0 {
return url.Path
}
procedure := path[penultimate:]
if len(procedure) < 4 { // two slashes + service + method
emcfarlane marked this conversation as resolved.
Show resolved Hide resolved
return url.Path
}
return procedure
}

// A Protocol defines the HTTP semantics to use when sending and receiving
// messages. It ties together codecs, compressors, and net/http to produce
// Senders and Receivers.
Expand Down
33 changes: 33 additions & 0 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package connect

import (
"net/url"
"testing"

"connectrpc.com/connect/internal/assert"
Expand Down Expand Up @@ -63,3 +64,35 @@ func BenchmarkCanonicalizeContentType(b *testing.B) {
b.ReportAllocs()
})
}

func TestProcedureFromURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
url string
want string
}{
{name: "simple", url: "http://localhost:8080/foo", want: "/foo"},
{name: "service", url: "http://localhost:8080/service/bar", want: "/service/bar"},
{name: "trailing", url: "http://localhost:8080/service/bar/", want: "/service/bar"},
{name: "subroute", url: "http://localhost:8080/api/service/bar/", want: "/service/bar"},
{name: "subrouteTrailing", url: "http://localhost:8080/api/service/bar/", want: "/service/bar"},
{
name: "real",
url: "http://localhost:8080/connect.ping.v1.PingService/Ping",
want: "/connect.ping.v1.PingService/Ping",
},
}
for _, testcase := range tests {
testcase := testcase
t.Run(testcase.name, func(t *testing.T) {
t.Parallel()
url, err := url.Parse(testcase.url)
if !assert.Nil(t, err) {
return
}
t.Log(url.String())
assert.Equal(t, InferProcedureFromURL(url), testcase.want)
})
}
}