diff --git a/protocol.go b/protocol.go index 8e20f9a7..f2ef8444 100644 --- a/protocol.go +++ b/protocol.go @@ -86,7 +86,7 @@ func ProcedureFromURL(url *url.URL) (string, bool) { if ultimate == len(path)-1 || penultimate == ultimate-1 { return url.Path, false } - return procedure, false + return procedure, true } // A Protocol defines the HTTP semantics to use when sending and receiving diff --git a/protocol_test.go b/protocol_test.go index 7bacea43..fe381dc9 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -15,6 +15,8 @@ package connect import ( + "net/http" + "net/http/httptest" "net/url" "testing" @@ -65,6 +67,78 @@ func BenchmarkCanonicalizeContentType(b *testing.B) { }) } +func TestProtocolFromRequest(t *testing.T) { + t.Parallel() + tests := []struct { + name string + contentType string + method string + want string + valid bool + }{{ + name: "connectUnary", + contentType: "application/json", + method: http.MethodPost, + want: ProtocolConnect, + valid: true, + }, { + name: "connectStreaming", + contentType: "application/connec+json", + method: http.MethodPost, + want: ProtocolConnect, + valid: true, + }, { + name: "grpcWeb", + contentType: "application/grpc-web", + method: http.MethodPost, + want: ProtocolGRPCWeb, + valid: true, + }, { + name: "grpc", + contentType: "application/grpc", + method: http.MethodPost, + want: ProtocolGRPC, + valid: true, + }, { + name: "connectGet", + contentType: "application/connec+json", + method: http.MethodGet, + want: ProtocolConnect, + valid: true, + }, { + name: "grpcWebGet", + contentType: "application/grpc-web", + method: http.MethodGet, + want: ProtocolConnect, + valid: true, + }, { + name: "grpcGet", + contentType: "application/grpc+json", + method: http.MethodGet, + want: ProtocolConnect, + valid: true, + }, { + name: "unknown", + contentType: "text/html", + method: http.MethodPost, + valid: false, + }} + for _, testcase := range tests { + testcase := testcase + t.Run(testcase.name, func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(testcase.method, "http://localhost:8080/service/Method", nil) + if testcase.contentType != "" { + req.Header.Set("Content-Type", testcase.contentType) + } + req.Method = testcase.method + got, valid := ProtocolFromRequest(req) + assert.Equal(t, got, testcase.want, assert.Sprintf("protocol")) + assert.Equal(t, valid, testcase.valid, assert.Sprintf("valid")) + }) + } +} + func TestProcedureFromURL(t *testing.T) { t.Parallel() tests := []struct { @@ -96,8 +170,9 @@ func TestProcedureFromURL(t *testing.T) { return } t.Log(url.String()) - got, _ := ProcedureFromURL(url) + got, valid := ProcedureFromURL(url) assert.Equal(t, got, testcase.want) + assert.Equal(t, valid, testcase.valid) }) } }