diff --git a/.gitignore b/.gitignore index a88ff6946..6b26bb045 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,7 @@ vendor .idea .vscode .DS_Store +server/certs/*.key +server/certs/*.crt +server/certs/*.csr +server/certs/*.srl diff --git a/CHANGELOG.md b/CHANGELOG.md index a687e2c55..f7758f6c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -127,6 +127,7 @@ * [ENHANCEMENT] BasicLifecycler: Added `RingTokenGenerator` configuration that specifies the `TokenGenerator` implementation that is used for token generation. Default value is nil, meaning that `RandomTokenGenerator` is used. #323 * [ENHANCEMENT] Ring: add support for hedging to `DoUntilQuorum` when request minimization is enabled. #330 * [ENHANCEMENT] Lifecycler: allow instances to register in ascending order of ids in case of spread minimizing token generation strategy. #326 +* [ENHANCEMENT] Remove dependency on `github.com/weaveworks/common` package by migrating code to a corresponding package in `github.com/grafana/dskit`. #342 * [BUGFIX] spanlogger: Support multiple tenant IDs. #59 * [BUGFIX] Memberlist: fixed corrupted packets when sending compound messages with more than 255 messages or messages bigger than 64KB. #85 * [BUGFIX] Ring: `ring_member_ownership_percent` and `ring_tokens_owned` metrics are not updated on scale down. #109 diff --git a/crypto/tls/test/tls_integration_test.go b/crypto/tls/test/tls_integration_test.go index bff15eb88..c61bc6c34 100644 --- a/crypto/tls/test/tls_integration_test.go +++ b/crypto/tls/test/tls_integration_test.go @@ -19,13 +19,14 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/weaveworks/common/server" "golang.org/x/time/rate" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/keepalive" + "github.com/grafana/dskit/server" + "github.com/grafana/dskit/backoff" "github.com/grafana/dskit/crypto/tls" ) diff --git a/errors/error.go b/errors/error.go new file mode 100644 index 000000000..1d36f7c39 --- /dev/null +++ b/errors/error.go @@ -0,0 +1,10 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/errors/error.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package errors + +// Error see https://dave.cheney.net/2016/04/07/constant-errors. +type Error string + +func (e Error) Error() string { return string(e) } diff --git a/go.mod b/go.mod index 1a1378bbe..5365fa1d4 100644 --- a/go.mod +++ b/go.mod @@ -9,11 +9,15 @@ require ( github.com/cespare/xxhash v1.1.0 github.com/cristalhq/hedgedhttp v0.7.0 github.com/facette/natsort v0.0.0-20181210072756-2cd4dd1e2dcb + github.com/felixge/httpsnoop v1.0.3 github.com/go-kit/log v0.2.1 github.com/go-redis/redis/v8 v8.11.5 + github.com/gogo/googleapis v1.1.0 github.com/gogo/protobuf v1.3.2 github.com/gogo/status v1.1.0 + github.com/golang/protobuf v1.5.3 github.com/golang/snappy v0.0.4 + github.com/gorilla/mux v1.8.0 github.com/grafana/gomemcache v0.0.0-20230316202710-a081dae0aba9 github.com/hashicorp/consul/api v1.15.3 github.com/hashicorp/go-cleanhttp v0.5.2 @@ -22,19 +26,27 @@ require ( github.com/hashicorp/memberlist v0.3.1 github.com/miekg/dns v1.1.50 github.com/opentracing-contrib/go-grpc v0.0.0-20210225150812-73cb765af46e + github.com/opentracing-contrib/go-stdlib v1.0.0 github.com/opentracing/opentracing-go v1.2.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.15.1 github.com/prometheus/client_model v0.4.0 github.com/prometheus/common v0.43.0 + github.com/prometheus/exporter-toolkit v0.8.2 + github.com/sercand/kuberesolver/v4 v4.0.0 + github.com/sirupsen/logrus v1.8.1 + github.com/soheilhy/cmux v0.1.5 github.com/stretchr/testify v1.8.1 - github.com/weaveworks/common v0.0.0-20230511094633-334485600903 + github.com/uber/jaeger-client-go v2.28.0+incompatible + github.com/uber/jaeger-lib v2.2.0+incompatible + github.com/weaveworks/promrus v1.2.0 go.etcd.io/etcd/api/v3 v3.5.0 go.etcd.io/etcd/client/pkg/v3 v3.5.0 go.etcd.io/etcd/client/v3 v3.5.0 go.uber.org/atomic v1.10.0 go.uber.org/goleak v1.2.0 golang.org/x/exp v0.0.0-20230321023759-10a507213a29 + golang.org/x/net v0.9.0 golang.org/x/sync v0.1.0 golang.org/x/time v0.1.0 google.golang.org/grpc v1.55.0 @@ -47,19 +59,16 @@ require ( github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.4.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fatih/color v1.13.0 // indirect - github.com/felixge/httpsnoop v1.0.3 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-logfmt/logfmt v0.5.1 // indirect - github.com/gogo/googleapis v1.1.0 // indirect - github.com/golang/protobuf v1.5.3 // indirect github.com/gomodule/redigo v1.8.9 // indirect github.com/google/btree v1.0.1 // indirect - github.com/gorilla/mux v1.8.0 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-hclog v0.14.1 // indirect github.com/hashicorp/go-immutable-radix v1.3.0 // indirect @@ -75,25 +84,16 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f // indirect github.com/onsi/gomega v1.24.0 // indirect - github.com/opentracing-contrib/go-stdlib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/prometheus/exporter-toolkit v0.8.2 // indirect github.com/prometheus/procfs v0.9.0 // indirect github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 // indirect - github.com/sercand/kuberesolver/v4 v4.0.0 // indirect - github.com/sirupsen/logrus v1.8.1 // indirect - github.com/soheilhy/cmux v0.1.5 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/stretchr/objx v0.5.0 // indirect - github.com/uber/jaeger-client-go v2.28.0+incompatible // indirect - github.com/uber/jaeger-lib v2.2.0+incompatible // indirect - github.com/weaveworks/promrus v1.2.0 // indirect github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect go.uber.org/multierr v1.6.0 // indirect go.uber.org/zap v1.17.0 // indirect golang.org/x/crypto v0.1.0 // indirect golang.org/x/mod v0.8.0 // indirect - golang.org/x/net v0.9.0 // indirect golang.org/x/oauth2 v0.7.0 // indirect golang.org/x/sys v0.8.0 // indirect golang.org/x/text v0.9.0 // indirect diff --git a/go.sum b/go.sum index 299f4518b..fa8a54cbe 100644 --- a/go.sum +++ b/go.sum @@ -409,8 +409,6 @@ github.com/armon/go-metrics v0.3.10 h1:FR+drcQStOe+32sYyJYyZ7FIdgoGGBnwLl+flodp8 github.com/armon/go-metrics v0.3.10/go.mod h1:4O98XIr/9W0sxpJ8UaYkvjk10Iff7SnFrb4QAOwNTFc= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= -github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= -github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -422,7 +420,6 @@ github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91 github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= @@ -490,7 +487,6 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2 github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= -github.com/go-kit/log v0.2.0/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU= github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= @@ -507,10 +503,8 @@ github.com/gogo/googleapis v1.1.0 h1:kFkMAZBNAn4j7K0GiZr8cRYzejq68VbheufiV3YuyFI github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/gogo/status v1.0.3/go.mod h1:SavQ51ycCLnc7dGyJxp8YAmudx8xqiVrRf+6IXRsugc= github.com/gogo/status v1.1.0 h1:+eIkrewn5q6b30y+g/BJINVVdi2xH7je5MPJ3ZPK3JA= github.com/gogo/status v1.1.0/go.mod h1:BFv9nrluPLmrS0EmGVvLaPNmRosr9KapBYd5/hpY1WM= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -609,7 +603,6 @@ github.com/googleapis/gax-go/v2 v2.6.0/go.mod h1:1mjbznJAPHFpesgE5ucqfYEscaz5kMd github.com/googleapis/gax-go/v2 v2.7.0/go.mod h1:TEop28CZZQ2y+c0VxMUmu1lV+fQx57QpBWsYpwqHJx8= github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= -github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/grafana/gomemcache v0.0.0-20230316202710-a081dae0aba9 h1:WB3bGH2f1UN6jkd6uAEWfHB8OD7dKJ0v2Oo6SNfhpfQ= @@ -666,19 +659,16 @@ github.com/hashicorp/serf v0.9.7/go.mod h1:TXZNMjZQijwlDvp+r0b63xZ45H7JmCmgg4gpT github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= -github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -702,7 +692,6 @@ github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= -github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= @@ -712,7 +701,6 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= -github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI= github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= @@ -730,7 +718,6 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= -github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f h1:KUppIJq7/+SVif2QVs3tOP0zanoHgBEVAwHxUSIzRqU= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= @@ -738,10 +725,8 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/gomega v1.24.0 h1:+0glovB9Jd6z3VR+ScSwQqXVTIfJcGA9UBM8yzQxhqg= github.com/onsi/gomega v1.24.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg= -github.com/opentracing-contrib/go-grpc v0.0.0-20180928155321-4b5a12d3ff02/go.mod h1:JNdpVEzCpXBgIiv4ds+TzhN1hrtxq6ClLrTlT9OQRSc= github.com/opentracing-contrib/go-grpc v0.0.0-20210225150812-73cb765af46e h1:4cPxUYdgaGzZIT5/j0IfqOrrXmq6bG8AwvwisMXpdrg= github.com/opentracing-contrib/go-grpc v0.0.0-20210225150812-73cb765af46e/go.mod h1:DYR5Eij8rJl8h7gblRrOZ8g0kW1umSpKqYIBTgeDtLo= -github.com/opentracing-contrib/go-stdlib v0.0.0-20190519235532-cf7a6c988dc9/go.mod h1:PLldrQSroqzH70Xl+1DQcGnefIbqsKR7UDaiux3zV+w= github.com/opentracing-contrib/go-stdlib v1.0.0 h1:TBS7YuVotp8myLon4Pv7BtCBzOTo1DeZCld0Z63mW2w= github.com/opentracing-contrib/go-stdlib v1.0.0/go.mod h1:qtI1ogk+2JhVPIXVc6q+NHziSmy2W5GbdQZFUHADCBU= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= @@ -766,24 +751,18 @@ github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5Fsn github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= -github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= -github.com/prometheus/client_golang v1.13.0/go.mod h1:vTeo+zgvILHsnnj/39Ou/1fPN5nJFOEMgftOUOmlvYQ= -github.com/prometheus/client_golang v1.14.0/go.mod h1:8vpkKitgIVNcqrRBWh1C4TIUQgYNtG/XQE4E/Zae36Y= github.com/prometheus/client_golang v1.15.1 h1:8tXpTmJbyH5lydzFPoxSIJ0J46jdh3tylbvM1xCv0LI= github.com/prometheus/client_golang v1.15.1/go.mod h1:e9yaBhRPU2pPNsZwE+JdQl0KEt1N9XgF6zxWmaC0xOk= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w= github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY= github.com/prometheus/client_model v0.4.0/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8bs7vj7HSQ4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= -github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= -github.com/prometheus/common v0.37.0/go.mod h1:phzohg0JFMnBEFGxTDbfu3QyL5GI8gTQJFhYO5B3mfA= github.com/prometheus/common v0.43.0 h1:iq+BVjvYLei5f27wiuNiB1DN6DYQkp1c8Bx0Vykh5us= github.com/prometheus/common v0.43.0/go.mod h1:NCvr5cQIh3Y/gy73/RdVtC9r8xxrxwJnB+2lB3BxrFc= github.com/prometheus/exporter-toolkit v0.8.2 h1:sbJAfBXQFkG6sUkbwBun8MNdzW9+wd5YfPYofbmj0YM= @@ -793,8 +772,6 @@ github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsT github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4= github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJfhI= github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -839,8 +816,6 @@ github.com/uber/jaeger-client-go v2.28.0+incompatible h1:G4QSBfvPKvg5ZM2j9MrJFdf github.com/uber/jaeger-client-go v2.28.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-lib v2.2.0+incompatible h1:MxZXOiR2JuoANZ3J6DE/U0kSFv/eJ/GfSYVCjK7dyaw= github.com/uber/jaeger-lib v2.2.0+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= -github.com/weaveworks/common v0.0.0-20230511094633-334485600903 h1:ph7R2CS/0o1gBzpzK/CioUKJVsXNVXfDGR8FZ9rMZIw= -github.com/weaveworks/common v0.0.0-20230511094633-334485600903/go.mod h1:rgbeLfJUtEr+G74cwFPR1k/4N0kDeaeSv/qhUNE4hm8= github.com/weaveworks/promrus v1.2.0 h1:jOLf6pe6/vss4qGHjXmGz4oDJQA+AOCqEL3FvvZGz7M= github.com/weaveworks/promrus v1.2.0/go.mod h1:SaE82+OJ91yqjrE1rsvBWVzNZKcHYFtMUyS1+Ogs/KA= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -867,7 +842,6 @@ go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.15.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U= -go.uber.org/atomic v1.5.1/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= @@ -888,7 +862,6 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20221012134737-56aed061732a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -976,10 +949,8 @@ golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLd golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= @@ -1112,7 +1083,6 @@ golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -1157,7 +1127,6 @@ golang.org/x/time v0.0.0-20220922220347-f3bd1da661af/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.1.0 h1:xYY+Bajn2a7VBmTM5GikTmnK8ZuX8YgnQCqZpbBNtmA= golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= @@ -1173,7 +1142,6 @@ golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20190907020128-2ca718005c18/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191113191852-77e3bb0ad9e7/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= diff --git a/grpcclient/instrumentation.go b/grpcclient/instrumentation.go index c8d352889..4a10ce48d 100644 --- a/grpcclient/instrumentation.go +++ b/grpcclient/instrumentation.go @@ -4,8 +4,9 @@ import ( otgrpc "github.com/opentracing-contrib/go-grpc" "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" - "github.com/weaveworks/common/middleware" "google.golang.org/grpc" + + "github.com/grafana/dskit/middleware" ) func Instrument(requestDuration *prometheus.HistogramVec) ([]grpc.UnaryClientInterceptor, []grpc.StreamClientInterceptor) { diff --git a/grpcutil/cancel.go b/grpcutil/cancel.go new file mode 100644 index 000000000..b1d369d2a --- /dev/null +++ b/grpcutil/cancel.go @@ -0,0 +1,25 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/grpc/cancel.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package grpcutil + +import ( + "context" + "errors" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// IsCanceled checks whether an error comes from an operation being canceled +func IsCanceled(err error) bool { + if errors.Is(err, context.Canceled) { + return true + } + s, ok := status.FromError(err) + if ok && s.Code() == codes.Canceled { + return true + } + return false +} diff --git a/httpgrpc/README.md b/httpgrpc/README.md new file mode 100644 index 000000000..4e4d7fe3d --- /dev/null +++ b/httpgrpc/README.md @@ -0,0 +1,9 @@ +**What?** Embedding HTTP requests and responses into a gRPC service; a service and client to translate back and forth between the two, so you can use them with your preferred mux. + +**Why?** Get all the goodness of protobuf encoding, HTTP/2, snappy, load balancing, persistent connection and native Kubernetes load balancing with ~none of the effort. + +To rebuild generated protobuf code, run: + + protoc -I ./ --go_out=plugins=grpc:./ ./httpgrpc.proto + +Follow the instructions here to get a working protoc: https://github.com/gogo/protobuf diff --git a/httpgrpc/httpgrpc.go b/httpgrpc/httpgrpc.go new file mode 100644 index 000000000..050492dfc --- /dev/null +++ b/httpgrpc/httpgrpc.go @@ -0,0 +1,59 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/httpgrpc/httpgrpc.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package httpgrpc + +import ( + "fmt" + + spb "github.com/gogo/googleapis/google/rpc" + "github.com/gogo/protobuf/types" + "github.com/gogo/status" + log "github.com/sirupsen/logrus" +) + +// Errorf returns a HTTP gRPC error than is correctly forwarded over +// gRPC, and can eventually be converted back to a HTTP response with +// HTTPResponseFromError. +func Errorf(code int, tmpl string, args ...interface{}) error { + return ErrorFromHTTPResponse(&HTTPResponse{ + Code: int32(code), + Body: []byte(fmt.Sprintf(tmpl, args...)), + }) +} + +// ErrorFromHTTPResponse converts an HTTP response into a grpc error +func ErrorFromHTTPResponse(resp *HTTPResponse) error { + a, err := types.MarshalAny(resp) + if err != nil { + return err + } + + return status.ErrorProto(&spb.Status{ + Code: resp.Code, + Message: string(resp.Body), + Details: []*types.Any{a}, + }) +} + +// HTTPResponseFromError converts a grpc error into an HTTP response +func HTTPResponseFromError(err error) (*HTTPResponse, bool) { + s, ok := status.FromError(err) + if !ok { + return nil, false + } + + status := s.Proto() + if len(status.Details) != 1 { + return nil, false + } + + var resp HTTPResponse + if err := types.UnmarshalAny(status.Details[0], &resp); err != nil { + log.Errorf("Got error containing non-response: %v", err) + return nil, false + } + + return &resp, true +} diff --git a/httpgrpc/httpgrpc.pb.go b/httpgrpc/httpgrpc.pb.go new file mode 100644 index 000000000..bab0efd53 --- /dev/null +++ b/httpgrpc/httpgrpc.pb.go @@ -0,0 +1,1311 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: httpgrpc.proto + +package httpgrpc + +import ( + bytes "bytes" + context "context" + fmt "fmt" + _ "github.com/gogo/protobuf/gogoproto" + proto "github.com/gogo/protobuf/proto" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + io "io" + math "math" + math_bits "math/bits" + reflect "reflect" + strings "strings" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type HTTPRequest struct { + Method string `protobuf:"bytes,1,opt,name=method,proto3" json:"method,omitempty"` + Url string `protobuf:"bytes,2,opt,name=url,proto3" json:"url,omitempty"` + Headers []*Header `protobuf:"bytes,3,rep,name=headers,proto3" json:"headers,omitempty"` + Body []byte `protobuf:"bytes,4,opt,name=body,proto3" json:"body,omitempty"` +} + +func (m *HTTPRequest) Reset() { *m = HTTPRequest{} } +func (*HTTPRequest) ProtoMessage() {} +func (*HTTPRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_c50820dbc814fcdd, []int{0} +} +func (m *HTTPRequest) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *HTTPRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_HTTPRequest.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *HTTPRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_HTTPRequest.Merge(m, src) +} +func (m *HTTPRequest) XXX_Size() int { + return m.Size() +} +func (m *HTTPRequest) XXX_DiscardUnknown() { + xxx_messageInfo_HTTPRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_HTTPRequest proto.InternalMessageInfo + +func (m *HTTPRequest) GetMethod() string { + if m != nil { + return m.Method + } + return "" +} + +func (m *HTTPRequest) GetUrl() string { + if m != nil { + return m.Url + } + return "" +} + +func (m *HTTPRequest) GetHeaders() []*Header { + if m != nil { + return m.Headers + } + return nil +} + +func (m *HTTPRequest) GetBody() []byte { + if m != nil { + return m.Body + } + return nil +} + +type HTTPResponse struct { + Code int32 `protobuf:"varint,1,opt,name=Code,proto3" json:"Code,omitempty"` + Headers []*Header `protobuf:"bytes,2,rep,name=headers,proto3" json:"headers,omitempty"` + Body []byte `protobuf:"bytes,3,opt,name=body,proto3" json:"body,omitempty"` +} + +func (m *HTTPResponse) Reset() { *m = HTTPResponse{} } +func (*HTTPResponse) ProtoMessage() {} +func (*HTTPResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_c50820dbc814fcdd, []int{1} +} +func (m *HTTPResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *HTTPResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_HTTPResponse.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *HTTPResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_HTTPResponse.Merge(m, src) +} +func (m *HTTPResponse) XXX_Size() int { + return m.Size() +} +func (m *HTTPResponse) XXX_DiscardUnknown() { + xxx_messageInfo_HTTPResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_HTTPResponse proto.InternalMessageInfo + +func (m *HTTPResponse) GetCode() int32 { + if m != nil { + return m.Code + } + return 0 +} + +func (m *HTTPResponse) GetHeaders() []*Header { + if m != nil { + return m.Headers + } + return nil +} + +func (m *HTTPResponse) GetBody() []byte { + if m != nil { + return m.Body + } + return nil +} + +type Header struct { + Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Values []string `protobuf:"bytes,2,rep,name=values,proto3" json:"values,omitempty"` +} + +func (m *Header) Reset() { *m = Header{} } +func (*Header) ProtoMessage() {} +func (*Header) Descriptor() ([]byte, []int) { + return fileDescriptor_c50820dbc814fcdd, []int{2} +} +func (m *Header) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Header) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Header.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Header) XXX_Merge(src proto.Message) { + xxx_messageInfo_Header.Merge(m, src) +} +func (m *Header) XXX_Size() int { + return m.Size() +} +func (m *Header) XXX_DiscardUnknown() { + xxx_messageInfo_Header.DiscardUnknown(m) +} + +var xxx_messageInfo_Header proto.InternalMessageInfo + +func (m *Header) GetKey() string { + if m != nil { + return m.Key + } + return "" +} + +func (m *Header) GetValues() []string { + if m != nil { + return m.Values + } + return nil +} + +func init() { + proto.RegisterType((*HTTPRequest)(nil), "httpgrpc.HTTPRequest") + proto.RegisterType((*HTTPResponse)(nil), "httpgrpc.HTTPResponse") + proto.RegisterType((*Header)(nil), "httpgrpc.Header") +} + +func init() { proto.RegisterFile("httpgrpc.proto", fileDescriptor_c50820dbc814fcdd) } + +var fileDescriptor_c50820dbc814fcdd = []byte{ + // 301 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x91, 0xbd, 0x4e, 0xc3, 0x30, + 0x14, 0x85, 0xed, 0xa6, 0x04, 0xea, 0x56, 0xa8, 0xb2, 0xa0, 0x8a, 0x3a, 0x5c, 0x55, 0x99, 0x22, + 0x86, 0x22, 0x05, 0x16, 0x46, 0x60, 0xc9, 0x88, 0xac, 0xbe, 0x40, 0x42, 0xac, 0x44, 0x22, 0xd4, + 0x21, 0x3f, 0xa0, 0x6e, 0x3c, 0x02, 0x8f, 0xc1, 0xa3, 0x30, 0x66, 0xec, 0x48, 0x9c, 0x85, 0xb1, + 0x8f, 0x80, 0xec, 0xa4, 0x10, 0x31, 0xb1, 0x9d, 0x7b, 0xee, 0x51, 0xbe, 0x7b, 0x62, 0x72, 0x1c, + 0x17, 0x45, 0x1a, 0x65, 0xe9, 0xfd, 0x32, 0xcd, 0x44, 0x21, 0xe8, 0xd1, 0x7e, 0x9e, 0x9f, 0x44, + 0x22, 0x12, 0xda, 0x3c, 0x57, 0xaa, 0xdd, 0xdb, 0x2f, 0x64, 0xec, 0xad, 0x56, 0x77, 0x8c, 0x3f, + 0x95, 0x3c, 0x2f, 0xe8, 0x8c, 0x98, 0x8f, 0xbc, 0x88, 0x45, 0x68, 0xe1, 0x05, 0x76, 0x46, 0xac, + 0x9b, 0xe8, 0x94, 0x18, 0x65, 0x96, 0x58, 0x03, 0x6d, 0x2a, 0x49, 0xcf, 0xc8, 0x61, 0xcc, 0xfd, + 0x90, 0x67, 0xb9, 0x65, 0x2c, 0x0c, 0x67, 0xec, 0x4e, 0x97, 0x3f, 0x68, 0x4f, 0x2f, 0xd8, 0x3e, + 0x40, 0x29, 0x19, 0x06, 0x22, 0xdc, 0x58, 0xc3, 0x05, 0x76, 0x26, 0x4c, 0x6b, 0x3b, 0x20, 0x93, + 0x16, 0x9c, 0xa7, 0x62, 0x9d, 0x73, 0x95, 0xb9, 0x15, 0x21, 0xd7, 0xdc, 0x03, 0xa6, 0x75, 0x9f, + 0x31, 0xf8, 0x2f, 0xc3, 0xe8, 0x31, 0x5c, 0x62, 0xb6, 0x31, 0x75, 0xff, 0x03, 0xdf, 0x74, 0xa5, + 0x94, 0x54, 0x4d, 0x9f, 0xfd, 0xa4, 0xe4, 0xed, 0xa7, 0x47, 0xac, 0x9b, 0xdc, 0x6b, 0x32, 0x54, + 0x77, 0xd1, 0x2b, 0x62, 0x7a, 0xfe, 0x3a, 0x4c, 0x38, 0x3d, 0xed, 0x41, 0x7f, 0x7f, 0xd5, 0x7c, + 0xf6, 0xd7, 0x6e, 0x8b, 0xd8, 0xe8, 0xe6, 0xb2, 0xaa, 0x01, 0x6d, 0x6b, 0x40, 0xbb, 0x1a, 0xf0, + 0xab, 0x04, 0xfc, 0x2e, 0x01, 0x7f, 0x48, 0xc0, 0x95, 0x04, 0xfc, 0x29, 0x01, 0x7f, 0x49, 0x40, + 0x3b, 0x09, 0xf8, 0xad, 0x01, 0x54, 0x35, 0x80, 0xb6, 0x0d, 0xa0, 0xc0, 0xd4, 0x0f, 0x72, 0xf1, + 0x1d, 0x00, 0x00, 0xff, 0xff, 0x44, 0x0e, 0x7c, 0xff, 0xc2, 0x01, 0x00, 0x00, +} + +func (this *HTTPRequest) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + that1, ok := that.(*HTTPRequest) + if !ok { + that2, ok := that.(HTTPRequest) + if ok { + that1 = &that2 + } else { + return false + } + } + if that1 == nil { + return this == nil + } else if this == nil { + return false + } + if this.Method != that1.Method { + return false + } + if this.Url != that1.Url { + return false + } + if len(this.Headers) != len(that1.Headers) { + return false + } + for i := range this.Headers { + if !this.Headers[i].Equal(that1.Headers[i]) { + return false + } + } + if !bytes.Equal(this.Body, that1.Body) { + return false + } + return true +} +func (this *HTTPResponse) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + that1, ok := that.(*HTTPResponse) + if !ok { + that2, ok := that.(HTTPResponse) + if ok { + that1 = &that2 + } else { + return false + } + } + if that1 == nil { + return this == nil + } else if this == nil { + return false + } + if this.Code != that1.Code { + return false + } + if len(this.Headers) != len(that1.Headers) { + return false + } + for i := range this.Headers { + if !this.Headers[i].Equal(that1.Headers[i]) { + return false + } + } + if !bytes.Equal(this.Body, that1.Body) { + return false + } + return true +} +func (this *Header) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + that1, ok := that.(*Header) + if !ok { + that2, ok := that.(Header) + if ok { + that1 = &that2 + } else { + return false + } + } + if that1 == nil { + return this == nil + } else if this == nil { + return false + } + if this.Key != that1.Key { + return false + } + if len(this.Values) != len(that1.Values) { + return false + } + for i := range this.Values { + if this.Values[i] != that1.Values[i] { + return false + } + } + return true +} +func (this *HTTPRequest) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 8) + s = append(s, "&httpgrpc.HTTPRequest{") + s = append(s, "Method: "+fmt.Sprintf("%#v", this.Method)+",\n") + s = append(s, "Url: "+fmt.Sprintf("%#v", this.Url)+",\n") + if this.Headers != nil { + s = append(s, "Headers: "+fmt.Sprintf("%#v", this.Headers)+",\n") + } + s = append(s, "Body: "+fmt.Sprintf("%#v", this.Body)+",\n") + s = append(s, "}") + return strings.Join(s, "") +} +func (this *HTTPResponse) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 7) + s = append(s, "&httpgrpc.HTTPResponse{") + s = append(s, "Code: "+fmt.Sprintf("%#v", this.Code)+",\n") + if this.Headers != nil { + s = append(s, "Headers: "+fmt.Sprintf("%#v", this.Headers)+",\n") + } + s = append(s, "Body: "+fmt.Sprintf("%#v", this.Body)+",\n") + s = append(s, "}") + return strings.Join(s, "") +} +func (this *Header) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 6) + s = append(s, "&httpgrpc.Header{") + s = append(s, "Key: "+fmt.Sprintf("%#v", this.Key)+",\n") + s = append(s, "Values: "+fmt.Sprintf("%#v", this.Values)+",\n") + s = append(s, "}") + return strings.Join(s, "") +} +func valueToGoStringHttpgrpc(v interface{}, typ string) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("func(v %v) *%v { return &v } ( %#v )", typ, typ, pv) +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// HTTPClient is the client API for HTTP service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type HTTPClient interface { + Handle(ctx context.Context, in *HTTPRequest, opts ...grpc.CallOption) (*HTTPResponse, error) +} + +type hTTPClient struct { + cc *grpc.ClientConn +} + +func NewHTTPClient(cc *grpc.ClientConn) HTTPClient { + return &hTTPClient{cc} +} + +func (c *hTTPClient) Handle(ctx context.Context, in *HTTPRequest, opts ...grpc.CallOption) (*HTTPResponse, error) { + out := new(HTTPResponse) + err := c.cc.Invoke(ctx, "/httpgrpc.HTTP/Handle", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// HTTPServer is the server API for HTTP service. +type HTTPServer interface { + Handle(context.Context, *HTTPRequest) (*HTTPResponse, error) +} + +// UnimplementedHTTPServer can be embedded to have forward compatible implementations. +type UnimplementedHTTPServer struct { +} + +func (*UnimplementedHTTPServer) Handle(ctx context.Context, req *HTTPRequest) (*HTTPResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Handle not implemented") +} + +func RegisterHTTPServer(s *grpc.Server, srv HTTPServer) { + s.RegisterService(&_HTTP_serviceDesc, srv) +} + +func _HTTP_Handle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(HTTPRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HTTPServer).Handle(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/httpgrpc.HTTP/Handle", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HTTPServer).Handle(ctx, req.(*HTTPRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _HTTP_serviceDesc = grpc.ServiceDesc{ + ServiceName: "httpgrpc.HTTP", + HandlerType: (*HTTPServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Handle", + Handler: _HTTP_Handle_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "httpgrpc.proto", +} + +func (m *HTTPRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *HTTPRequest) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *HTTPRequest) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.Body) > 0 { + i -= len(m.Body) + copy(dAtA[i:], m.Body) + i = encodeVarintHttpgrpc(dAtA, i, uint64(len(m.Body))) + i-- + dAtA[i] = 0x22 + } + if len(m.Headers) > 0 { + for iNdEx := len(m.Headers) - 1; iNdEx >= 0; iNdEx-- { + { + size, err := m.Headers[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintHttpgrpc(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x1a + } + } + if len(m.Url) > 0 { + i -= len(m.Url) + copy(dAtA[i:], m.Url) + i = encodeVarintHttpgrpc(dAtA, i, uint64(len(m.Url))) + i-- + dAtA[i] = 0x12 + } + if len(m.Method) > 0 { + i -= len(m.Method) + copy(dAtA[i:], m.Method) + i = encodeVarintHttpgrpc(dAtA, i, uint64(len(m.Method))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *HTTPResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *HTTPResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *HTTPResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.Body) > 0 { + i -= len(m.Body) + copy(dAtA[i:], m.Body) + i = encodeVarintHttpgrpc(dAtA, i, uint64(len(m.Body))) + i-- + dAtA[i] = 0x1a + } + if len(m.Headers) > 0 { + for iNdEx := len(m.Headers) - 1; iNdEx >= 0; iNdEx-- { + { + size, err := m.Headers[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintHttpgrpc(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x12 + } + } + if m.Code != 0 { + i = encodeVarintHttpgrpc(dAtA, i, uint64(m.Code)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *Header) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Header) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Header) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.Values) > 0 { + for iNdEx := len(m.Values) - 1; iNdEx >= 0; iNdEx-- { + i -= len(m.Values[iNdEx]) + copy(dAtA[i:], m.Values[iNdEx]) + i = encodeVarintHttpgrpc(dAtA, i, uint64(len(m.Values[iNdEx]))) + i-- + dAtA[i] = 0x12 + } + } + if len(m.Key) > 0 { + i -= len(m.Key) + copy(dAtA[i:], m.Key) + i = encodeVarintHttpgrpc(dAtA, i, uint64(len(m.Key))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func encodeVarintHttpgrpc(dAtA []byte, offset int, v uint64) int { + offset -= sovHttpgrpc(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *HTTPRequest) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Method) + if l > 0 { + n += 1 + l + sovHttpgrpc(uint64(l)) + } + l = len(m.Url) + if l > 0 { + n += 1 + l + sovHttpgrpc(uint64(l)) + } + if len(m.Headers) > 0 { + for _, e := range m.Headers { + l = e.Size() + n += 1 + l + sovHttpgrpc(uint64(l)) + } + } + l = len(m.Body) + if l > 0 { + n += 1 + l + sovHttpgrpc(uint64(l)) + } + return n +} + +func (m *HTTPResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Code != 0 { + n += 1 + sovHttpgrpc(uint64(m.Code)) + } + if len(m.Headers) > 0 { + for _, e := range m.Headers { + l = e.Size() + n += 1 + l + sovHttpgrpc(uint64(l)) + } + } + l = len(m.Body) + if l > 0 { + n += 1 + l + sovHttpgrpc(uint64(l)) + } + return n +} + +func (m *Header) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Key) + if l > 0 { + n += 1 + l + sovHttpgrpc(uint64(l)) + } + if len(m.Values) > 0 { + for _, s := range m.Values { + l = len(s) + n += 1 + l + sovHttpgrpc(uint64(l)) + } + } + return n +} + +func sovHttpgrpc(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozHttpgrpc(x uint64) (n int) { + return sovHttpgrpc(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (this *HTTPRequest) String() string { + if this == nil { + return "nil" + } + repeatedStringForHeaders := "[]*Header{" + for _, f := range this.Headers { + repeatedStringForHeaders += strings.Replace(f.String(), "Header", "Header", 1) + "," + } + repeatedStringForHeaders += "}" + s := strings.Join([]string{`&HTTPRequest{`, + `Method:` + fmt.Sprintf("%v", this.Method) + `,`, + `Url:` + fmt.Sprintf("%v", this.Url) + `,`, + `Headers:` + repeatedStringForHeaders + `,`, + `Body:` + fmt.Sprintf("%v", this.Body) + `,`, + `}`, + }, "") + return s +} +func (this *HTTPResponse) String() string { + if this == nil { + return "nil" + } + repeatedStringForHeaders := "[]*Header{" + for _, f := range this.Headers { + repeatedStringForHeaders += strings.Replace(f.String(), "Header", "Header", 1) + "," + } + repeatedStringForHeaders += "}" + s := strings.Join([]string{`&HTTPResponse{`, + `Code:` + fmt.Sprintf("%v", this.Code) + `,`, + `Headers:` + repeatedStringForHeaders + `,`, + `Body:` + fmt.Sprintf("%v", this.Body) + `,`, + `}`, + }, "") + return s +} +func (this *Header) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&Header{`, + `Key:` + fmt.Sprintf("%v", this.Key) + `,`, + `Values:` + fmt.Sprintf("%v", this.Values) + `,`, + `}`, + }, "") + return s +} +func valueToStringHttpgrpc(v interface{}) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("*%v", pv) +} +func (m *HTTPRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: HTTPRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: HTTPRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Method", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthHttpgrpc + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthHttpgrpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Method = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Url", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthHttpgrpc + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthHttpgrpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Url = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Headers", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthHttpgrpc + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthHttpgrpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Headers = append(m.Headers, &Header{}) + if err := m.Headers[len(m.Headers)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Body", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthHttpgrpc + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthHttpgrpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Body = append(m.Body[:0], dAtA[iNdEx:postIndex]...) + if m.Body == nil { + m.Body = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipHttpgrpc(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthHttpgrpc + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthHttpgrpc + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *HTTPResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: HTTPResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: HTTPResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Code", wireType) + } + m.Code = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Code |= int32(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Headers", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthHttpgrpc + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthHttpgrpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Headers = append(m.Headers, &Header{}) + if err := m.Headers[len(m.Headers)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Body", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthHttpgrpc + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthHttpgrpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Body = append(m.Body[:0], dAtA[iNdEx:postIndex]...) + if m.Body == nil { + m.Body = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipHttpgrpc(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthHttpgrpc + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthHttpgrpc + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Header) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Header: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Header: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Key", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthHttpgrpc + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthHttpgrpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Key = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Values", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthHttpgrpc + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthHttpgrpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Values = append(m.Values, string(dAtA[iNdEx:postIndex])) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipHttpgrpc(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthHttpgrpc + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthHttpgrpc + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipHttpgrpc(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthHttpgrpc + } + iNdEx += length + if iNdEx < 0 { + return 0, ErrInvalidLengthHttpgrpc + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowHttpgrpc + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipHttpgrpc(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + if iNdEx < 0 { + return 0, ErrInvalidLengthHttpgrpc + } + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthHttpgrpc = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowHttpgrpc = fmt.Errorf("proto: integer overflow") +) diff --git a/httpgrpc/httpgrpc.proto b/httpgrpc/httpgrpc.proto new file mode 100644 index 000000000..8f546330a --- /dev/null +++ b/httpgrpc/httpgrpc.proto @@ -0,0 +1,35 @@ +syntax = "proto3"; + +package httpgrpc; + +import "gogoproto/gogo.proto"; + +option (gogoproto.equal_all) = true; +option (gogoproto.gostring_all) = true; +option (gogoproto.stringer_all) = true; +option (gogoproto.goproto_stringer_all) = false; +option (gogoproto.goproto_unkeyed_all) = false; +option (gogoproto.goproto_unrecognized_all) = false; +option (gogoproto.goproto_sizecache_all) = false; + +service HTTP { + rpc Handle(HTTPRequest) returns (HTTPResponse) {}; +} + +message HTTPRequest { + string method = 1; + string url = 2; + repeated Header headers = 3; + bytes body = 4; +} + +message HTTPResponse { + int32 Code = 1; + repeated Header headers = 2; + bytes body = 3; +} + +message Header { + string key = 1; + repeated string values = 2; +} diff --git a/httpgrpc/server/server.go b/httpgrpc/server/server.go new file mode 100644 index 000000000..7b715bd56 --- /dev/null +++ b/httpgrpc/server/server.go @@ -0,0 +1,235 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/httpgrpc/server/server.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package server + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + + otgrpc "github.com/opentracing-contrib/go-grpc" + "github.com/opentracing/opentracing-go" + "github.com/sercand/kuberesolver/v4" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/grafana/dskit/httpgrpc" + "github.com/grafana/dskit/log" + "github.com/grafana/dskit/middleware" +) + +// Server implements HTTPServer. HTTPServer is a generated interface that gRPC +// servers must implement. +type Server struct { + handler http.Handler +} + +// NewServer makes a new Server. +func NewServer(handler http.Handler) *Server { + return &Server{ + handler: handler, + } +} + +type nopCloser struct { + *bytes.Buffer +} + +func (nopCloser) Close() error { return nil } + +// BytesBuffer returns the underlaying `bytes.buffer` used to build this io.ReadCloser. +func (n nopCloser) BytesBuffer() *bytes.Buffer { return n.Buffer } + +// Handle implements HTTPServer. +func (s Server) Handle(ctx context.Context, r *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) { + req, err := http.NewRequest(r.Method, r.Url, nopCloser{Buffer: bytes.NewBuffer(r.Body)}) + if err != nil { + return nil, err + } + toHeader(r.Headers, req.Header) + req = req.WithContext(ctx) + req.RequestURI = r.Url + req.ContentLength = int64(len(r.Body)) + + recorder := httptest.NewRecorder() + s.handler.ServeHTTP(recorder, req) + resp := &httpgrpc.HTTPResponse{ + Code: int32(recorder.Code), + Headers: fromHeader(recorder.Header()), + Body: recorder.Body.Bytes(), + } + if recorder.Code/100 == 5 { + return nil, httpgrpc.ErrorFromHTTPResponse(resp) + } + return resp, nil +} + +// Client is a http.Handler that forwards the request over gRPC. +type Client struct { + client httpgrpc.HTTPClient + conn *grpc.ClientConn +} + +// ParseURL deals with direct:// style URLs, as well as kubernetes:// urls. +// For backwards compatibility it treats URLs without schems as kubernetes://. +func ParseURL(unparsed string) (string, error) { + // if it has :///, this is the kuberesolver v2 URL. Return it as it is. + if strings.Contains(unparsed, ":///") { + return unparsed, nil + } + + parsed, err := url.Parse(unparsed) + if err != nil { + return "", err + } + + scheme, host := parsed.Scheme, parsed.Host + if !strings.Contains(unparsed, "://") { + scheme, host = "kubernetes", unparsed + } + + switch scheme { + case "direct": + return host, err + + case "kubernetes": + host, port, err := net.SplitHostPort(host) + if err != nil { + return "", err + } + parts := strings.SplitN(host, ".", 3) + service, domain := parts[0], "" + if len(parts) > 1 { + namespace := parts[1] + domain = "." + namespace + } + if len(parts) > 2 { + domain = domain + "." + parts[2] + } + address := fmt.Sprintf("kubernetes:///%s%s:%s", service, domain, port) + return address, nil + + default: + return "", fmt.Errorf("unrecognised scheme: %s", parsed.Scheme) + } +} + +// NewClient makes a new Client, given a kubernetes service address. +func NewClient(address string) (*Client, error) { + kuberesolver.RegisterInCluster() + + address, err := ParseURL(address) + if err != nil { + return nil, err + } + const grpcServiceConfig = `{"loadBalancingPolicy":"round_robin"}` + + dialOptions := []grpc.DialOption{ + grpc.WithDefaultServiceConfig(grpcServiceConfig), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithChainUnaryInterceptor( + otgrpc.OpenTracingClientInterceptor(opentracing.GlobalTracer()), + middleware.ClientUserHeaderInterceptor, + ), + } + + conn, err := grpc.Dial(address, dialOptions...) + if err != nil { + return nil, err + } + + return &Client{ + client: httpgrpc.NewHTTPClient(conn), + conn: conn, + }, nil +} + +// HTTPRequest wraps an ordinary HTTPRequest with a gRPC one +func HTTPRequest(r *http.Request) (*httpgrpc.HTTPRequest, error) { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return nil, err + } + return &httpgrpc.HTTPRequest{ + Method: r.Method, + Url: r.RequestURI, + Body: body, + Headers: fromHeader(r.Header), + }, nil +} + +// WriteResponse converts an httpgrpc response to an HTTP one +func WriteResponse(w http.ResponseWriter, resp *httpgrpc.HTTPResponse) error { + toHeader(resp.Headers, w.Header()) + w.WriteHeader(int(resp.Code)) + _, err := w.Write(resp.Body) + return err +} + +// WriteError converts an httpgrpc error to an HTTP one +func WriteError(w http.ResponseWriter, err error) { + resp, ok := httpgrpc.HTTPResponseFromError(err) + if ok { + _ = WriteResponse(w, resp) + } else { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// ServeHTTP implements http.Handler +func (c *Client) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if tracer := opentracing.GlobalTracer(); tracer != nil { + if span := opentracing.SpanFromContext(r.Context()); span != nil { + if err := tracer.Inject(span.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(r.Header)); err != nil { + log.Global().Warnf("Failed to inject tracing headers into request: %v", err) + } + } + } + + req, err := HTTPRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + resp, err := c.client.Handle(r.Context(), req) + if err != nil { + // Some errors will actually contain a valid resp, just need to unpack it + var ok bool + resp, ok = httpgrpc.HTTPResponseFromError(err) + + if !ok { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + + if err := WriteResponse(w, resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +func toHeader(hs []*httpgrpc.Header, header http.Header) { + for _, h := range hs { + header[h.Key] = h.Values + } +} + +func fromHeader(hs http.Header) []*httpgrpc.Header { + result := make([]*httpgrpc.Header, 0, len(hs)) + for k, vs := range hs { + result = append(result, &httpgrpc.Header{ + Key: k, + Values: vs, + }) + } + return result +} diff --git a/httpgrpc/server/server_test.go b/httpgrpc/server/server_test.go new file mode 100644 index 000000000..1c32a1a0c --- /dev/null +++ b/httpgrpc/server/server_test.go @@ -0,0 +1,153 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/httpgrpc/server/server_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package server + +import ( + "bytes" + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "testing" + + opentracing "github.com/opentracing/opentracing-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + jaegercfg "github.com/uber/jaeger-client-go/config" + "google.golang.org/grpc" + + "github.com/grafana/dskit/httpgrpc" + "github.com/grafana/dskit/middleware" + "github.com/grafana/dskit/user" +) + +type testServer struct { + *Server + URL string + grpcServer *grpc.Server +} + +func newTestServer(t *testing.T, handler http.Handler) (*testServer, error) { + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, err + } + + server := &testServer{ + Server: NewServer(handler), + grpcServer: grpc.NewServer(), + URL: "direct://" + lis.Addr().String(), + } + + httpgrpc.RegisterHTTPServer(server.grpcServer, server.Server) + go func() { + require.NoError(t, server.grpcServer.Serve(lis)) + }() + + return server, nil +} + +func TestBasic(t *testing.T) { + server, err := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := fmt.Fprint(w, "world") + require.NoError(t, err) + })) + require.NoError(t, err) + defer server.grpcServer.GracefulStop() + + client, err := NewClient(server.URL) + require.NoError(t, err) + + req, err := http.NewRequest("GET", "/hello", &bytes.Buffer{}) + require.NoError(t, err) + + req = req.WithContext(user.InjectOrgID(context.Background(), "1")) + recorder := httptest.NewRecorder() + client.ServeHTTP(recorder, req) + + assert.Equal(t, "world", recorder.Body.String()) + assert.Equal(t, 200, recorder.Code) +} + +func TestError(t *testing.T) { + server, err := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Does a Fprintln, injecting a newline. + http.Error(w, "foo", http.StatusInternalServerError) + })) + require.NoError(t, err) + defer server.grpcServer.GracefulStop() + + client, err := NewClient(server.URL) + require.NoError(t, err) + + req, err := http.NewRequest("GET", "/hello", &bytes.Buffer{}) + require.NoError(t, err) + + req = req.WithContext(user.InjectOrgID(context.Background(), "1")) + recorder := httptest.NewRecorder() + client.ServeHTTP(recorder, req) + + assert.Equal(t, "foo\n", recorder.Body.String()) + assert.Equal(t, 500, recorder.Code) +} + +func TestParseURL(t *testing.T) { + for _, tc := range []struct { + input string + expected string + err string + }{ + {"direct://foo", "foo", ""}, + {"kubernetes://foo:123", "kubernetes:///foo:123", ""}, + {"querier.cortex:995", "kubernetes:///querier.cortex:995", ""}, + {"foo.bar.svc.local:995", "kubernetes:///foo.bar.svc.local:995", ""}, + {"kubernetes:///foo:123", "kubernetes:///foo:123", ""}, + {"dns:///foo.bar.svc.local:995", "dns:///foo.bar.svc.local:995", ""}, + {"monster://foo:995", "", "unrecognised scheme: monster"}, + } { + got, err := ParseURL(tc.input) + if tc.err == "" { + require.NoError(t, err) + } else { + require.EqualError(t, err, tc.err) + } + assert.Equal(t, tc.expected, got) + } +} + +func TestTracePropagation(t *testing.T) { + jaeger := jaegercfg.Configuration{} + closer, err := jaeger.InitGlobalTracer("test") + require.NoError(t, err) + defer closer.Close() + + server, err := newTestServer(t, middleware.Tracer{}.Wrap( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + span := opentracing.SpanFromContext(r.Context()) + _, err := fmt.Fprint(w, span.BaggageItem("name")) + require.NoError(t, err) + }), + )) + + require.NoError(t, err) + defer server.grpcServer.GracefulStop() + + client, err := NewClient(server.URL) + require.NoError(t, err) + + req, err := http.NewRequest("GET", "/hello", &bytes.Buffer{}) + require.NoError(t, err) + + sp, ctx := opentracing.StartSpanFromContext(context.Background(), "Test") + sp.SetBaggageItem("name", "world") + + req = req.WithContext(user.InjectOrgID(ctx, "1")) + recorder := httptest.NewRecorder() + client.ServeHTTP(recorder, req) + + assert.Equal(t, "world", recorder.Body.String()) + assert.Equal(t, 200, recorder.Code) +} diff --git a/httpgrpc/tools.go b/httpgrpc/tools.go new file mode 100644 index 000000000..9117d39b1 --- /dev/null +++ b/httpgrpc/tools.go @@ -0,0 +1,10 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/httpgrpc/tools.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package httpgrpc + +import ( + // This is a workaround for go mod which fails to download gogoproto otherwise + _ "github.com/gogo/protobuf/gogoproto" +) diff --git a/instrument/instrument.go b/instrument/instrument.go new file mode 100644 index 000000000..4ea480b29 --- /dev/null +++ b/instrument/instrument.go @@ -0,0 +1,192 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/instrument/instrument.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +//lint:file-ignore faillint Changing from prometheus to promauto package would be a breaking change for consumers + +package instrument + +import ( + "context" + "time" + + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/ext" + otlog "github.com/opentracing/opentracing-go/log" + "github.com/prometheus/client_golang/prometheus" + + "github.com/grafana/dskit/grpcutil" + "github.com/grafana/dskit/tracing" + "github.com/grafana/dskit/user" +) + +// DefBuckets are histogram buckets for the response time (in seconds) +// of a network service, including one that is responding very slowly. +var DefBuckets = []float64{.005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10, 25, 50, 100} + +// Collector describes something that collects data before and/or after a task. +type Collector interface { + Register() + Before(ctx context.Context, method string, start time.Time) + After(ctx context.Context, method, statusCode string, start time.Time) +} + +// HistogramCollector collects the duration of a request +type HistogramCollector struct { + metric *prometheus.HistogramVec +} + +// HistogramCollectorBuckets define the buckets when passing the metric +var HistogramCollectorBuckets = []string{"operation", "status_code"} + +// NewHistogramCollectorFromOpts creates a Collector from histogram options. +// It makes sure that the buckets are named properly and should be preferred over +// NewHistogramCollector(). +func NewHistogramCollectorFromOpts(opts prometheus.HistogramOpts) *HistogramCollector { + metric := prometheus.NewHistogramVec(opts, HistogramCollectorBuckets) + return &HistogramCollector{metric} +} + +// NewHistogramCollector creates a Collector from a metric. +func NewHistogramCollector(metric *prometheus.HistogramVec) *HistogramCollector { + return &HistogramCollector{metric} +} + +// Register registers metrics. +func (c *HistogramCollector) Register() { + prometheus.MustRegister(c.metric) +} + +// Before collects for the upcoming request. +func (c *HistogramCollector) Before(context.Context, string, time.Time) { +} + +// After collects when the request is done. +func (c *HistogramCollector) After(ctx context.Context, method, statusCode string, start time.Time) { + if c.metric != nil { + ObserveWithExemplar(ctx, c.metric.WithLabelValues(method, statusCode), time.Since(start).Seconds()) + } +} + +// ObserveWithExemplar adds a sample to a histogram, and adds an exemplar if the context has a sampled trace. +// 'histogram' parameter must be castable to prometheus.ExemplarObserver or function will panic +// (this will always work for a HistogramVec). +func ObserveWithExemplar(ctx context.Context, histogram prometheus.Observer, seconds float64) { + if traceID, ok := tracing.ExtractSampledTraceID(ctx); ok { + histogram.(prometheus.ExemplarObserver).ObserveWithExemplar( + seconds, + prometheus.Labels{"traceID": traceID}, + ) + return + } + histogram.Observe(seconds) +} + +// JobCollector collects metrics for jobs. Designed for batch jobs which run on a regular, +// not-too-frequent, non-overlapping interval. We can afford to measure duration directly +// with gauges, and compute quantile with quantile_over_time. +type JobCollector struct { + start, end, duration *prometheus.GaugeVec + started, completed *prometheus.CounterVec +} + +// NewJobCollector instantiates JobCollector which creates its metrics. +func NewJobCollector(namespace string) *JobCollector { + return &JobCollector{ + start: prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: namespace, + Subsystem: "job", + Name: "latest_start_timestamp", + Help: "Unix UTC timestamp of most recent job start time", + }, []string{"operation"}), + end: prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: namespace, + Subsystem: "job", + Name: "latest_end_timestamp", + Help: "Unix UTC timestamp of most recent job end time", + }, []string{"operation", "status_code"}), + duration: prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: namespace, + Subsystem: "job", + Name: "latest_duration_seconds", + Help: "duration of most recent job", + }, []string{"operation", "status_code"}), + started: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: "job", + Name: "started_total", + Help: "Number of jobs started", + }, []string{"operation"}), + completed: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: "job", + Name: "completed_total", + Help: "Number of jobs completed", + }, []string{"operation", "status_code"}), + } +} + +// Register registers metrics. +func (c *JobCollector) Register() { + prometheus.MustRegister(c.start) + prometheus.MustRegister(c.end) + prometheus.MustRegister(c.duration) + prometheus.MustRegister(c.started) + prometheus.MustRegister(c.completed) +} + +// Before collects for the upcoming request. +func (c *JobCollector) Before(_ context.Context, method string, start time.Time) { + c.start.WithLabelValues(method).Set(float64(start.UTC().Unix())) + c.started.WithLabelValues(method).Inc() +} + +// After collects when the request is done. +func (c *JobCollector) After(_ context.Context, method, statusCode string, start time.Time) { + end := time.Now() + c.end.WithLabelValues(method, statusCode).Set(float64(end.UTC().Unix())) + c.duration.WithLabelValues(method, statusCode).Set(end.Sub(start).Seconds()) + c.completed.WithLabelValues(method, statusCode).Inc() +} + +// CollectedRequest runs a tracked request. It uses the given Collector to monitor requests. +// +// If `f` returns no error we log "200" as status code, otherwise "500". Pass in a function +// for `toStatusCode` to overwrite this behaviour. It will also emit an OpenTracing span if +// you have a global tracer configured. +func CollectedRequest(ctx context.Context, method string, col Collector, toStatusCode func(error) string, f func(context.Context) error) error { + if toStatusCode == nil { + toStatusCode = ErrorCode + } + sp, newCtx := opentracing.StartSpanFromContext(ctx, method) + ext.SpanKindRPCClient.Set(sp) + if userID, err := user.ExtractUserID(ctx); err == nil { + sp.SetTag("user", userID) + } + if orgID, err := user.ExtractOrgID(ctx); err == nil { + sp.SetTag("organization", orgID) + } + + start := time.Now() + col.Before(newCtx, method, start) + err := f(newCtx) + col.After(newCtx, method, toStatusCode(err), start) + + if err != nil { + if !grpcutil.IsCanceled(err) { + ext.Error.Set(sp, true) + } + sp.LogFields(otlog.Error(err)) + } + sp.Finish() + + return err +} + +// ErrorCode converts an error into an HTTP status code +func ErrorCode(err error) string { + if err == nil { + return "200" + } + return "500" +} diff --git a/instrument/instrument_test.go b/instrument/instrument_test.go new file mode 100644 index 000000000..9d0cc2904 --- /dev/null +++ b/instrument/instrument_test.go @@ -0,0 +1,76 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/instrument/instrument_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +//lint:file-ignore faillint Changing from prometheus to promauto package would be a breaking change for consumers + +package instrument_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + + "github.com/grafana/dskit/instrument" +) + +func TestNewHistogramCollector(t *testing.T) { + m := prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "test", + Subsystem: "instrumentation", + Name: "foo", + Help: "", + Buckets: prometheus.DefBuckets, + }, instrument.HistogramCollectorBuckets) + c := instrument.NewHistogramCollector(m) + assert.NotNil(t, c) +} + +type spyCollector struct { + before bool + after bool + afterCode string +} + +func (c *spyCollector) Register() { +} + +// Before collects for the upcoming request. +func (c *spyCollector) Before(context.Context, string, time.Time) { + c.before = true +} + +// After collects when the request is done. +func (c *spyCollector) After(_ context.Context, _, statusCode string, _ time.Time) { + c.after = true + c.afterCode = statusCode +} + +func TestCollectedRequest(t *testing.T) { + c := &spyCollector{} + fcalled := false + err := instrument.CollectedRequest(context.Background(), "test", c, nil, func(_ context.Context) error { + fcalled = true + return nil + }) + assert.NoError(t, err) + assert.True(t, fcalled) + assert.True(t, c.before) + assert.True(t, c.after) + assert.Equal(t, "200", c.afterCode) +} + +func TestCollectedRequest_Error(t *testing.T) { + c := &spyCollector{} + err := instrument.CollectedRequest(context.Background(), "test", c, nil, func(_ context.Context) error { + return errors.New("boom") + }) + assert.EqualError(t, err, "boom") + assert.True(t, c.before) + assert.True(t, c.after) + assert.Equal(t, "500", c.afterCode) +} diff --git a/kv/consul/client.go b/kv/consul/client.go index 861e03f56..5501a67d8 100644 --- a/kv/consul/client.go +++ b/kv/consul/client.go @@ -14,9 +14,10 @@ import ( consul "github.com/hashicorp/consul/api" "github.com/hashicorp/go-cleanhttp" "github.com/prometheus/client_golang/prometheus" - "github.com/weaveworks/common/instrument" "golang.org/x/time/rate" + "github.com/grafana/dskit/instrument" + "github.com/grafana/dskit/backoff" "github.com/grafana/dskit/flagext" "github.com/grafana/dskit/kv/codec" diff --git a/kv/consul/metrics.go b/kv/consul/metrics.go index 52a1d4e84..166e79bc9 100644 --- a/kv/consul/metrics.go +++ b/kv/consul/metrics.go @@ -6,7 +6,8 @@ import ( consul "github.com/hashicorp/consul/api" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/weaveworks/common/instrument" + + "github.com/grafana/dskit/instrument" ) type consulInstrumentation struct { diff --git a/kv/metrics.go b/kv/metrics.go index 66fe9fa91..7361b8c41 100644 --- a/kv/metrics.go +++ b/kv/metrics.go @@ -6,8 +6,9 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/weaveworks/common/httpgrpc" - "github.com/weaveworks/common/instrument" + + "github.com/grafana/dskit/httpgrpc" + "github.com/grafana/dskit/instrument" ) // RegistererWithKVName wraps the provided Registerer with the KV name label. If a nil reg diff --git a/log/format.go b/log/format.go new file mode 100644 index 000000000..3925b3c14 --- /dev/null +++ b/log/format.go @@ -0,0 +1,56 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/logging/format.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package log + +import ( + "flag" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +// Format is a settable identifier for the output format of logs +type Format struct { + s string + Logrus logrus.Formatter +} + +// RegisterFlags adds the log format flag to the provided flagset. +func (f *Format) RegisterFlags(fs *flag.FlagSet) { + _ = f.Set("logfmt") + fs.Var(f, "log.format", "Output log messages in the given format. Valid formats: [logfmt, json]") +} + +func (f Format) String() string { + return f.s +} + +// UnmarshalYAML implements yaml.Unmarshaler. +func (f *Format) UnmarshalYAML(unmarshal func(interface{}) error) error { + var format string + if err := unmarshal(&format); err != nil { + return err + } + return f.Set(format) +} + +// MarshalYAML implements yaml.Marshaler. +func (f Format) MarshalYAML() (interface{}, error) { + return f.String(), nil +} + +// Set updates the value of the output format. Implements flag.Value +func (f *Format) Set(s string) error { + switch s { + case "logfmt": + f.Logrus = &logrus.JSONFormatter{} + case "json": + f.Logrus = &logrus.JSONFormatter{} + default: + return errors.Errorf("unrecognized log format %q", s) + } + f.s = s + return nil +} diff --git a/log/global.go b/log/global.go new file mode 100644 index 000000000..68131a156 --- /dev/null +++ b/log/global.go @@ -0,0 +1,62 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/logging/global.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package log + +var global = Noop() + +// Global returns the global logger. +func Global() Interface { + return global +} + +// SetGlobal sets the global logger. +func SetGlobal(i Interface) { + global = i +} + +// Debugf convenience function calls the global loggerr. +func Debugf(format string, args ...interface{}) { + global.Debugf(format, args...) +} + +// Debugln convenience function calls the global logger. +func Debugln(args ...interface{}) { + global.Debugln(args...) +} + +// Infof convenience function calls the global logger. +func Infof(format string, args ...interface{}) { + global.Infof(format, args...) +} + +// Infoln convenience function calls the global logger. +func Infoln(args ...interface{}) { + global.Infoln(args...) +} + +// Warnf convenience function calls the global logger. +func Warnf(format string, args ...interface{}) { + global.Warnf(format, args...) +} + +// Warnln convenience function calls the global logger. +func Warnln(args ...interface{}) { + global.Warnln(args...) +} + +// Errorf convenience function calls the global logger. +func Errorf(format string, args ...interface{}) { + global.Errorf(format, args...) +} + +// Errorln convenience function calls the global logger. +func Errorln(args ...interface{}) { + global.Errorln(args...) +} + +// WithField convenience function calls the global logger. +func WithField(key string, value interface{}) Interface { + return global.WithField(key, value) +} diff --git a/log/gokit.go b/log/gokit.go new file mode 100644 index 000000000..c956c6775 --- /dev/null +++ b/log/gokit.go @@ -0,0 +1,106 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/logging/gokit.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package log + +import ( + "fmt" + "os" + + "github.com/go-kit/log" + "github.com/go-kit/log/level" +) + +// NewGoKitFormat creates a new Interface backed by a GoKit logger +// format can be "json" or defaults to logfmt +func NewGoKitFormat(l Level, f Format) Interface { + var logger log.Logger + if f.s == "json" { + logger = log.NewJSONLogger(log.NewSyncWriter(os.Stderr)) + } else { + logger = log.NewLogfmtLogger(log.NewSyncWriter(os.Stderr)) + } + return addStandardFields(logger, l) +} + +// stand-alone for test purposes +func addStandardFields(logger log.Logger, l Level) Interface { + logger = log.With(logger, "ts", log.DefaultTimestampUTC, "caller", log.Caller(5)) + logger = level.NewFilter(logger, l.Gokit) + return gokit{logger} +} + +// NewGoKit creates a new Interface backed by a GoKit logger +func NewGoKit(l Level) Interface { + return NewGoKitFormat(l, Format{s: "logfmt"}) +} + +// GoKit wraps an existing gokit Logger. +func GoKit(logger log.Logger) Interface { + return gokit{logger} +} + +type gokit struct { + log.Logger +} + +// Helper to defer sprintf until it is needed. +type sprintf struct { + format string + args []interface{} +} + +func (s *sprintf) String() string { + return fmt.Sprintf(s.format, s.args...) +} + +// Helper to defer sprint until it is needed. +// Note we don't use Sprintln because the output is passed to go-kit as one value among many on a line +type sprint struct { + args []interface{} +} + +func (s *sprint) String() string { + return fmt.Sprint(s.args...) +} + +func (g gokit) Debugf(format string, args ...interface{}) { + level.Debug(g.Logger).Log("msg", &sprintf{format: format, args: args}) +} +func (g gokit) Debugln(args ...interface{}) { + level.Debug(g.Logger).Log("msg", &sprint{args: args}) +} + +func (g gokit) Infof(format string, args ...interface{}) { + level.Info(g.Logger).Log("msg", &sprintf{format: format, args: args}) +} +func (g gokit) Infoln(args ...interface{}) { + level.Info(g.Logger).Log("msg", &sprint{args: args}) +} + +func (g gokit) Warnf(format string, args ...interface{}) { + level.Warn(g.Logger).Log("msg", &sprintf{format: format, args: args}) +} +func (g gokit) Warnln(args ...interface{}) { + level.Warn(g.Logger).Log("msg", &sprint{args: args}) +} + +func (g gokit) Errorf(format string, args ...interface{}) { + level.Error(g.Logger).Log("msg", &sprintf{format: format, args: args}) +} +func (g gokit) Errorln(args ...interface{}) { + level.Error(g.Logger).Log("msg", &sprint{args: args}) +} + +func (g gokit) WithField(key string, value interface{}) Interface { + return gokit{log.With(g.Logger, key, value)} +} + +func (g gokit) WithFields(fields Fields) Interface { + logger := g.Logger + for k, v := range fields { + logger = log.With(logger, k, v) + } + return gokit{logger} +} diff --git a/log/gokit_test.go b/log/gokit_test.go new file mode 100644 index 000000000..b1ca0fbdd --- /dev/null +++ b/log/gokit_test.go @@ -0,0 +1,29 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/logging/gokit_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package log + +import ( + "testing" + + "github.com/go-kit/log" + "github.com/go-kit/log/level" +) + +func BenchmarkDebugf(b *testing.B) { + lvl := Level{Gokit: level.AllowInfo()} + g := log.NewNopLogger() + logger := addStandardFields(g, lvl) + // Simulate the parameters used in middleware/logging.go + var ( + method = "method" + uri = "https://example.com/foobar" + statusCode = 404 + duration = 42 + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + logger.Debugf("%s %s (%d) %s", method, uri, statusCode, duration) + } +} diff --git a/log/interface.go b/log/interface.go new file mode 100644 index 000000000..a074fef90 --- /dev/null +++ b/log/interface.go @@ -0,0 +1,28 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/logging/interface.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package log + +// Interface 'unifies' gokit logging and logrus logging, such that +// the middleware in this repo can be used in projects which use either +// loggers. +type Interface interface { + Debugf(format string, args ...interface{}) + Debugln(args ...interface{}) + + Infof(format string, args ...interface{}) + Infoln(args ...interface{}) + + Errorf(format string, args ...interface{}) + Errorln(args ...interface{}) + + Warnf(format string, args ...interface{}) + Warnln(args ...interface{}) + + WithField(key string, value interface{}) Interface + WithFields(Fields) Interface +} + +// Fields convenience type for adding multiple fields to a log statement. +type Fields map[string]interface{} diff --git a/log/level.go b/log/level.go new file mode 100644 index 000000000..f2b8db55b --- /dev/null +++ b/log/level.go @@ -0,0 +1,82 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/logging/level.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package log + +// Copy-pasted from prometheus/common/promlog. +// Copyright 2017 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import ( + "flag" + + "github.com/go-kit/log/level" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +// Level is a settable identifier for the minimum level a log entry +// must be have. +type Level struct { + s string + Logrus logrus.Level + Gokit level.Option +} + +// RegisterFlags adds the log level flag to the provided flagset. +func (l *Level) RegisterFlags(f *flag.FlagSet) { + _ = l.Set("info") + f.Var(l, "log.level", "Only log messages with the given severity or above. Valid levels: [debug, info, warn, error]") +} + +func (l *Level) String() string { + return l.s +} + +// UnmarshalYAML implements yaml.Unmarshaler. +func (l *Level) UnmarshalYAML(unmarshal func(interface{}) error) error { + var level string + if err := unmarshal(&level); err != nil { + return err + } + return l.Set(level) +} + +// MarshalYAML implements yaml.Marshaler. +func (l Level) MarshalYAML() (interface{}, error) { + return l.String(), nil +} + +// Set updates the value of the allowed level. Implments flag.Value. +func (l *Level) Set(s string) error { + switch s { + case "debug": + l.Logrus = logrus.DebugLevel + l.Gokit = level.AllowDebug() + case "info": + l.Logrus = logrus.InfoLevel + l.Gokit = level.AllowInfo() + case "warn": + l.Logrus = logrus.WarnLevel + l.Gokit = level.AllowWarn() + case "error": + l.Logrus = logrus.ErrorLevel + l.Gokit = level.AllowError() + default: + return errors.Errorf("unrecognized log level %q", s) + } + + l.s = s + return nil +} diff --git a/log/level_test.go b/log/level_test.go new file mode 100644 index 000000000..d3532890b --- /dev/null +++ b/log/level_test.go @@ -0,0 +1,28 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/logging/level_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package log + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v2" +) + +func TestMarshalYAML(t *testing.T) { + var l Level + err := l.Set("debug") + require.NoError(t, err) + + // Test the non-pointed to Level, as people might embed it. + y, err := yaml.Marshal(l) + require.NoError(t, err) + require.Equal(t, []byte("debug\n"), y) + + // And the pointed to Level. + y, err = yaml.Marshal(&l) + require.NoError(t, err) + require.Equal(t, []byte("debug\n"), y) +} diff --git a/log/logging.go b/log/logging.go new file mode 100644 index 000000000..5bb80dee2 --- /dev/null +++ b/log/logging.go @@ -0,0 +1,32 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/logging/logging.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package log + +import ( + "fmt" + "os" + + "github.com/sirupsen/logrus" + + "github.com/weaveworks/promrus" +) + +// Setup configures a global logrus logger to output to stderr. +// It populates the standard logrus logger as well as the global logging instance. +func Setup(logLevel string) error { + level, err := logrus.ParseLevel(logLevel) + if err != nil { + return fmt.Errorf("error parsing log level: %v", err) + } + hook, err := promrus.NewPrometheusHook() // Expose number of log messages as Prometheus metrics. + if err != nil { + return err + } + logrus.SetOutput(os.Stderr) + logrus.SetLevel(level) + logrus.AddHook(hook) + SetGlobal(Logrus(logrus.StandardLogger())) + return nil +} diff --git a/log/logrus.go b/log/logrus.go new file mode 100644 index 000000000..df0e1ae07 --- /dev/null +++ b/log/logrus.go @@ -0,0 +1,63 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/logging/logrus.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package log + +import ( + "os" + + "github.com/sirupsen/logrus" +) + +// NewLogrusFormat makes a new Interface backed by a logrus logger +// format can be "json" or defaults to logfmt +func NewLogrusFormat(level Level, f Format) Interface { + log := logrus.New() + log.Out = os.Stderr + log.Level = level.Logrus + log.Formatter = f.Logrus + return logrusLogger{log} +} + +// NewLogrus makes a new Interface backed by a logrus logger +func NewLogrus(level Level) Interface { + return NewLogrusFormat(level, Format{Logrus: &logrus.TextFormatter{}}) +} + +// Logrus wraps an existing Logrus logger. +func Logrus(l *logrus.Logger) Interface { + return logrusLogger{l} +} + +type logrusLogger struct { + *logrus.Logger +} + +func (l logrusLogger) WithField(key string, value interface{}) Interface { + return logrusEntry{ + Entry: l.Logger.WithField(key, value), + } +} + +func (l logrusLogger) WithFields(fields Fields) Interface { + return logrusEntry{ + Entry: l.Logger.WithFields(map[string]interface{}(fields)), + } +} + +type logrusEntry struct { + *logrus.Entry +} + +func (l logrusEntry) WithField(key string, value interface{}) Interface { + return logrusEntry{ + Entry: l.Entry.WithField(key, value), + } +} + +func (l logrusEntry) WithFields(fields Fields) Interface { + return logrusEntry{ + Entry: l.Entry.WithFields(map[string]interface{}(fields)), + } +} diff --git a/log/noop.go b/log/noop.go new file mode 100644 index 000000000..89d437468 --- /dev/null +++ b/log/noop.go @@ -0,0 +1,27 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/logging/noop.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package log + +// Noop logger. +func Noop() Interface { + return noop{} +} + +type noop struct{} + +func (noop) Debugf(string, ...interface{}) {} +func (noop) Debugln(...interface{}) {} +func (noop) Infof(string, ...interface{}) {} +func (noop) Infoln(...interface{}) {} +func (noop) Warnf(string, ...interface{}) {} +func (noop) Warnln(...interface{}) {} +func (noop) Errorf(string, ...interface{}) {} +func (noop) Errorln(...interface{}) {} +func (noop) WithField(string, interface{}) Interface { + return noop{} +} +func (noop) WithFields(Fields) Interface { + return noop{} +} diff --git a/middleware/counting_listener.go b/middleware/counting_listener.go new file mode 100644 index 000000000..961f71a51 --- /dev/null +++ b/middleware/counting_listener.go @@ -0,0 +1,47 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/counting_listener.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "net" + "sync" + + "github.com/prometheus/client_golang/prometheus" +) + +// CountingListener returns a Listener that increments a Prometheus gauge when +// a connection is accepted, and decrements the gauge when the connection is closed. +func CountingListener(l net.Listener, g prometheus.Gauge) net.Listener { + return &countingListener{Listener: l, gauge: g} +} + +type countingListener struct { + net.Listener + gauge prometheus.Gauge +} + +func (c *countingListener) Accept() (net.Conn, error) { + conn, err := c.Listener.Accept() + if err != nil { + return nil, err + } + c.gauge.Inc() + return &countingListenerConn{Conn: conn, gauge: c.gauge}, nil +} + +type countingListenerConn struct { + net.Conn + gauge prometheus.Gauge + once sync.Once +} + +func (l *countingListenerConn) Close() error { + err := l.Conn.Close() + + // Only ever decrement the gauge once in case of badly behaving callers. + l.once.Do(func() { l.gauge.Dec() }) + + return err +} diff --git a/middleware/counting_listener_test.go b/middleware/counting_listener_test.go new file mode 100644 index 000000000..d98584e5b --- /dev/null +++ b/middleware/counting_listener_test.go @@ -0,0 +1,80 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/counting_listener_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "errors" + "net" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" +) + +type fakeListener struct { + net.Listener + acceptErr error + closeErr error +} + +type fakeConn struct { + net.Conn + closeErr error +} + +func (c *fakeConn) Close() error { + return c.closeErr +} + +func (c *fakeListener) Accept() (net.Conn, error) { + return &fakeConn{closeErr: c.closeErr}, c.acceptErr +} + +func TestCountingListener(t *testing.T) { + reg := prometheus.NewPedanticRegistry() + g := promauto.With(reg).NewGauge(prometheus.GaugeOpts{ + Namespace: "test", + Name: "gauge", + }) + + fake := &fakeListener{} + l := CountingListener(fake, g) + assert.Equal(t, float64(0), testutil.ToFloat64(g)) + + // Accepting connections should increment the gauge. + c1, err := l.Accept() + assert.NoError(t, err) + assert.Equal(t, float64(1), testutil.ToFloat64(g)) + c2, err := l.Accept() + assert.NoError(t, err) + assert.Equal(t, float64(2), testutil.ToFloat64(g)) + + // Closing connections should decrement the gauge. + assert.NoError(t, c1.Close()) + assert.Equal(t, float64(1), testutil.ToFloat64(g)) + assert.NoError(t, c2.Close()) + assert.Equal(t, float64(0), testutil.ToFloat64(g)) + + // Duplicate calls to Close should not decrement. + assert.NoError(t, c1.Close()) + assert.Equal(t, float64(0), testutil.ToFloat64(g)) + + // Accept errors should not cause an increment. + fake.acceptErr = errors.New("accept") + _, err = l.Accept() + assert.Error(t, err) + assert.Equal(t, float64(0), testutil.ToFloat64(g)) + + // Close errors should still decrement. + fake.acceptErr = nil + fake.closeErr = errors.New("close") + c3, err := l.Accept() + assert.NoError(t, err) + assert.Equal(t, float64(1), testutil.ToFloat64(g)) + assert.Error(t, c3.Close()) + assert.Equal(t, float64(0), testutil.ToFloat64(g)) +} diff --git a/middleware/errorhandler.go b/middleware/errorhandler.go new file mode 100644 index 000000000..4a48c90c9 --- /dev/null +++ b/middleware/errorhandler.go @@ -0,0 +1,103 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/errorhandler.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "bufio" + "fmt" + "net" + "net/http" +) + +func copyHeaders(src, dest http.Header) { + for k, v := range src { + dest[k] = v + } +} + +// ErrorHandler lets you call an alternate http handler upon a certain response code. +// Note it will assume a 200 if the wrapped handler does not write anything +type ErrorHandler struct { + Code int + Handler http.Handler +} + +// Wrap implements Middleware +func (e ErrorHandler) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + i := newErrorInterceptor(w, e.Code) + next.ServeHTTP(i, r) + if !i.gotCode { + i.WriteHeader(http.StatusOK) + } + if i.intercepted { + e.Handler.ServeHTTP(w, r) + } + }) +} + +// errorInterceptor wraps an underlying ResponseWriter and buffers all header changes, until it knows the return code. +// It then passes everything through, unless the code matches the target code, in which case it will discard everything. +type errorInterceptor struct { + originalWriter http.ResponseWriter + targetCode int + headers http.Header + gotCode bool + intercepted bool +} + +func newErrorInterceptor(w http.ResponseWriter, code int) *errorInterceptor { + i := errorInterceptor{originalWriter: w, targetCode: code} + i.headers = make(http.Header) + copyHeaders(w.Header(), i.headers) + return &i +} + +// Unwrap method is used by http.ResponseController to get access to original http.ResponseWriter. +func (i *errorInterceptor) Unwrap() http.ResponseWriter { + return i.originalWriter +} + +// Header implements http.ResponseWriter +func (i *errorInterceptor) Header() http.Header { + return i.headers +} + +// WriteHeader implements http.ResponseWriter +func (i *errorInterceptor) WriteHeader(code int) { + if i.gotCode { + panic("errorInterceptor.WriteHeader() called twice") + } + + i.gotCode = true + if code == i.targetCode { + i.intercepted = true + } else { + copyHeaders(i.headers, i.originalWriter.Header()) + i.originalWriter.WriteHeader(code) + } +} + +// Write implements http.ResponseWriter +func (i *errorInterceptor) Write(data []byte) (int, error) { + if !i.gotCode { + i.WriteHeader(http.StatusOK) + } + if !i.intercepted { + return i.originalWriter.Write(data) + } + return len(data), nil +} + +// errorInterceptor also implements net.Hijacker, to let the downstream Handler +// hijack the connection. This is needed, for example, for working with websockets. +func (i *errorInterceptor) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := i.originalWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("error interceptor: can't cast original ResponseWriter to Hijacker") + } + i.gotCode = true + return hj.Hijack() +} diff --git a/middleware/grpc_auth.go b/middleware/grpc_auth.go new file mode 100644 index 000000000..156ddaf10 --- /dev/null +++ b/middleware/grpc_auth.go @@ -0,0 +1,66 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/grpc_auth.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/grafana/dskit/user" +) + +// ClientUserHeaderInterceptor propagates the user ID from the context to gRPC metadata, which eventually ends up as a HTTP2 header. +func ClientUserHeaderInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctx, err := user.InjectIntoGRPCRequest(ctx) + if err != nil { + return err + } + + return invoker(ctx, method, req, reply, cc, opts...) +} + +// StreamClientUserHeaderInterceptor propagates the user ID from the context to gRPC metadata, which eventually ends up as a HTTP2 header. +// For streaming gRPC requests. +func StreamClientUserHeaderInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + ctx, err := user.InjectIntoGRPCRequest(ctx) + if err != nil { + return nil, err + } + + return streamer(ctx, desc, cc, method, opts...) +} + +// ServerUserHeaderInterceptor propagates the user ID from the gRPC metadata back to our context. +func ServerUserHeaderInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + _, ctx, err := user.ExtractFromGRPCRequest(ctx) + if err != nil { + return nil, err + } + + return handler(ctx, req) +} + +// StreamServerUserHeaderInterceptor propagates the user ID from the gRPC metadata back to our context. +func StreamServerUserHeaderInterceptor(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + _, ctx, err := user.ExtractFromGRPCRequest(ss.Context()) + if err != nil { + return err + } + + return handler(srv, serverStream{ + ctx: ctx, + ServerStream: ss, + }) +} + +type serverStream struct { + ctx context.Context + grpc.ServerStream +} + +func (ss serverStream) Context() context.Context { + return ss.ctx +} diff --git a/middleware/grpc_instrumentation.go b/middleware/grpc_instrumentation.go new file mode 100644 index 000000000..4a0899c25 --- /dev/null +++ b/middleware/grpc_instrumentation.go @@ -0,0 +1,141 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/grpc_instrumentation.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "context" + "io" + "strconv" + "time" + + "github.com/prometheus/client_golang/prometheus" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + grpcUtils "github.com/grafana/dskit/grpcutil" + "github.com/grafana/dskit/httpgrpc" + "github.com/grafana/dskit/instrument" +) + +func observe(ctx context.Context, hist *prometheus.HistogramVec, method string, err error, duration time.Duration) { + respStatus := "success" + if err != nil { + if errResp, ok := httpgrpc.HTTPResponseFromError(err); ok { + respStatus = strconv.Itoa(int(errResp.Code)) + } else if grpcUtils.IsCanceled(err) { + respStatus = "cancel" + } else { + respStatus = "error" + } + } + instrument.ObserveWithExemplar(ctx, hist.WithLabelValues(gRPC, method, respStatus, "false"), duration.Seconds()) +} + +// UnaryServerInstrumentInterceptor instruments gRPC requests for errors and latency. +func UnaryServerInstrumentInterceptor(hist *prometheus.HistogramVec) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + begin := time.Now() + resp, err := handler(ctx, req) + observe(ctx, hist, info.FullMethod, err, time.Since(begin)) + return resp, err + } +} + +// StreamServerInstrumentInterceptor instruments gRPC requests for errors and latency. +func StreamServerInstrumentInterceptor(hist *prometheus.HistogramVec) grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + begin := time.Now() + err := handler(srv, ss) + observe(ss.Context(), hist, info.FullMethod, err, time.Since(begin)) + return err + } +} + +// UnaryClientInstrumentInterceptor records duration of gRPC requests client side. +func UnaryClientInstrumentInterceptor(metric *prometheus.HistogramVec) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, resp interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + start := time.Now() + err := invoker(ctx, method, req, resp, cc, opts...) + metric.WithLabelValues(method, errorCode(err)).Observe(time.Since(start).Seconds()) + return err + } +} + +// StreamClientInstrumentInterceptor records duration of streaming gRPC requests client side. +func StreamClientInstrumentInterceptor(metric *prometheus.HistogramVec) grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, + streamer grpc.Streamer, opts ...grpc.CallOption, + ) (grpc.ClientStream, error) { + start := time.Now() + stream, err := streamer(ctx, desc, cc, method, opts...) + return &instrumentedClientStream{ + metric: metric, + start: start, + method: method, + ClientStream: stream, + }, err + } +} + +type instrumentedClientStream struct { + metric *prometheus.HistogramVec + start time.Time + method string + grpc.ClientStream +} + +func (s *instrumentedClientStream) SendMsg(m interface{}) error { + err := s.ClientStream.SendMsg(m) + if err == nil { + return nil + } + + if err == io.EOF { + s.metric.WithLabelValues(s.method, errorCode(nil)).Observe(time.Since(s.start).Seconds()) + } else { + s.metric.WithLabelValues(s.method, errorCode(err)).Observe(time.Since(s.start).Seconds()) + } + + return err +} + +func (s *instrumentedClientStream) RecvMsg(m interface{}) error { + err := s.ClientStream.RecvMsg(m) + if err == nil { + return nil + } + + if err == io.EOF { + s.metric.WithLabelValues(s.method, errorCode(nil)).Observe(time.Since(s.start).Seconds()) + } else { + s.metric.WithLabelValues(s.method, errorCode(err)).Observe(time.Since(s.start).Seconds()) + } + + return err +} + +func (s *instrumentedClientStream) Header() (metadata.MD, error) { + md, err := s.ClientStream.Header() + if err != nil { + s.metric.WithLabelValues(s.method, errorCode(err)).Observe(time.Since(s.start).Seconds()) + } + return md, err +} + +// errorCode converts an error into an error code string. +func errorCode(err error) string { + if err == nil { + return "2xx" + } + + if errResp, ok := httpgrpc.HTTPResponseFromError(err); ok { + statusFamily := int(errResp.Code / 100) + return strconv.Itoa(statusFamily) + "xx" + } else if grpcUtils.IsCanceled(err) { + return "cancel" + } else { + return "error" + } +} diff --git a/middleware/grpc_instrumentation_test.go b/middleware/grpc_instrumentation_test.go new file mode 100644 index 000000000..19ad6865d --- /dev/null +++ b/middleware/grpc_instrumentation_test.go @@ -0,0 +1,45 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/grpc_instrumentation_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/grafana/dskit/httpgrpc" +) + +func TestErrorCode_NoError(t *testing.T) { + a := errorCode(nil) + assert.Equal(t, "2xx", a) +} + +func TestErrorCode_Any5xx(t *testing.T) { + err := httpgrpc.Errorf(http.StatusNotImplemented, "Fail") + a := errorCode(err) + assert.Equal(t, "5xx", a) +} + +func TestErrorCode_Any4xx(t *testing.T) { + err := httpgrpc.Errorf(http.StatusConflict, "Fail") + a := errorCode(err) + assert.Equal(t, "4xx", a) +} + +func TestErrorCode_Canceled(t *testing.T) { + err := status.Errorf(codes.Canceled, "Fail") + a := errorCode(err) + assert.Equal(t, "cancel", a) +} + +func TestErrorCode_Unknown(t *testing.T) { + err := status.Errorf(codes.Unknown, "Fail") + a := errorCode(err) + assert.Equal(t, "error", a) +} diff --git a/middleware/grpc_logging.go b/middleware/grpc_logging.go new file mode 100644 index 000000000..bb9c99571 --- /dev/null +++ b/middleware/grpc_logging.go @@ -0,0 +1,84 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/grpc_logging.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "context" + "errors" + "time" + + "google.golang.org/grpc" + + grpcUtils "github.com/grafana/dskit/grpcutil" + "github.com/grafana/dskit/log" + "github.com/grafana/dskit/user" +) + +const ( + gRPC = "gRPC" + errorKey = "err" +) + +// An error can implement ShouldLog() to control whether GRPCServerLog will log. +type OptionalLogging interface { + ShouldLog(ctx context.Context, duration time.Duration) bool +} + +// GRPCServerLog logs grpc requests, errors, and latency. +type GRPCServerLog struct { + Log log.Interface + // WithRequest will log the entire request rather than just the error + WithRequest bool + DisableRequestSuccessLog bool +} + +// UnaryServerInterceptor returns an interceptor that logs gRPC requests +func (s GRPCServerLog) UnaryServerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + begin := time.Now() + resp, err := handler(ctx, req) + if err == nil && s.DisableRequestSuccessLog { + return resp, nil + } + var optional OptionalLogging + if errors.As(err, &optional) && !optional.ShouldLog(ctx, time.Since(begin)) { + return resp, err + } + + entry := user.LogWith(ctx, s.Log).WithFields(log.Fields{"method": info.FullMethod, "duration": time.Since(begin)}) + if err != nil { + if s.WithRequest { + entry = entry.WithField("request", req) + } + if grpcUtils.IsCanceled(err) { + entry.WithField(errorKey, err).Debugln(gRPC) + } else { + entry.WithField(errorKey, err).Warnln(gRPC) + } + } else { + entry.Debugf("%s (success)", gRPC) + } + return resp, err +} + +// StreamServerInterceptor returns an interceptor that logs gRPC requests +func (s GRPCServerLog) StreamServerInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + begin := time.Now() + err := handler(srv, ss) + if err == nil && s.DisableRequestSuccessLog { + return nil + } + + entry := user.LogWith(ss.Context(), s.Log).WithFields(log.Fields{"method": info.FullMethod, "duration": time.Since(begin)}) + if err != nil { + if grpcUtils.IsCanceled(err) { + entry.WithField(errorKey, err).Debugln(gRPC) + } else { + entry.WithField(errorKey, err).Warnln(gRPC) + } + } else { + entry.Debugf("%s (success)", gRPC) + } + return err +} diff --git a/middleware/grpc_logging_test.go b/middleware/grpc_logging_test.go new file mode 100644 index 000000000..16be43fe7 --- /dev/null +++ b/middleware/grpc_logging_test.go @@ -0,0 +1,85 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/grpc_logging_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "bytes" + "context" + "errors" + "testing" + "time" + + "github.com/go-kit/log" + "github.com/go-kit/log/level" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + dskit_log "github.com/grafana/dskit/log" +) + +func BenchmarkGRPCServerLog_UnaryServerInterceptor_NoError(b *testing.B) { + logger := dskit_log.GoKit(level.NewFilter(log.NewNopLogger(), level.AllowError())) + l := GRPCServerLog{Log: logger, WithRequest: false, DisableRequestSuccessLog: true} + ctx := context.Background() + info := &grpc.UnaryServerInfo{FullMethod: "Test"} + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + } + + b.ResetTimer() + b.ReportAllocs() + + for n := 0; n < b.N; n++ { + _, _ = l.UnaryServerInterceptor(ctx, nil, info, handler) + } +} + +type doNotLogError struct{ Err error } + +func (i doNotLogError) Error() string { return i.Err.Error() } +func (i doNotLogError) Unwrap() error { return i.Err } +func (i doNotLogError) ShouldLog(_ context.Context, _ time.Duration) bool { return false } + +func TestGrpcLogging(t *testing.T) { + ctx := context.Background() + info := &grpc.UnaryServerInfo{FullMethod: "Test"} + for _, tc := range []struct { + err error + logContains []string + }{{ + err: context.Canceled, + logContains: []string{"level=debug", "context canceled"}, + }, { + err: errors.New("yolo"), + logContains: []string{"level=warn", "err=yolo"}, + }, { + err: nil, + logContains: []string{"level=debug", "method=Test"}, + }, { + err: doNotLogError{Err: errors.New("yolo")}, + logContains: nil, + }} { + t.Run("", func(t *testing.T) { + buf := bytes.NewBuffer(nil) + logger := dskit_log.GoKit(log.NewLogfmtLogger(buf)) + l := GRPCServerLog{Log: logger, WithRequest: true, DisableRequestSuccessLog: false} + + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, tc.err + } + + _, err := l.UnaryServerInterceptor(ctx, nil, info, handler) + require.ErrorIs(t, tc.err, err) + + if len(tc.logContains) == 0 { + require.Empty(t, buf) + } + for _, content := range tc.logContains { + require.Contains(t, buf.String(), content) + } + }) + } +} diff --git a/middleware/grpc_stats.go b/middleware/grpc_stats.go new file mode 100644 index 000000000..3d29d9baa --- /dev/null +++ b/middleware/grpc_stats.go @@ -0,0 +1,73 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/grpc_stats.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "context" + + "github.com/prometheus/client_golang/prometheus" + "google.golang.org/grpc/stats" +) + +// NewStatsHandler creates handler that can be added to gRPC server options to track received and sent message sizes. +func NewStatsHandler(receivedPayloadSize, sentPayloadSize *prometheus.HistogramVec, inflightRequests *prometheus.GaugeVec) stats.Handler { + return &grpcStatsHandler{ + receivedPayloadSize: receivedPayloadSize, + sentPayloadSize: sentPayloadSize, + inflightRequests: inflightRequests, + } +} + +type grpcStatsHandler struct { + receivedPayloadSize *prometheus.HistogramVec + sentPayloadSize *prometheus.HistogramVec + inflightRequests *prometheus.GaugeVec +} + +// Custom type to hide it from other packages. +type contextKey int + +const ( + contextKeyMethodName contextKey = 1 +) + +func (g *grpcStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { + return context.WithValue(ctx, contextKeyMethodName, info.FullMethodName) +} + +func (g *grpcStatsHandler) HandleRPC(ctx context.Context, rpcStats stats.RPCStats) { + // We use full method name from context, because not all RPCStats structs have it. + fullMethodName, ok := ctx.Value(contextKeyMethodName).(string) + if !ok { + return + } + + switch s := rpcStats.(type) { + case *stats.Begin: + g.inflightRequests.WithLabelValues(gRPC, fullMethodName).Inc() + case *stats.End: + g.inflightRequests.WithLabelValues(gRPC, fullMethodName).Dec() + case *stats.InHeader: + // Ignore incoming headers. + case *stats.InPayload: + g.receivedPayloadSize.WithLabelValues(gRPC, fullMethodName).Observe(float64(s.WireLength)) + case *stats.InTrailer: + // Ignore incoming trailers. + case *stats.OutHeader: + // Ignore outgoing headers. + case *stats.OutPayload: + g.sentPayloadSize.WithLabelValues(gRPC, fullMethodName).Observe(float64(s.WireLength)) + case *stats.OutTrailer: + // Ignore outgoing trailers. OutTrailer doesn't have valid WireLength (there is a deprecated field, always set to 0). + } +} + +func (g *grpcStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { + return ctx +} + +func (g *grpcStatsHandler) HandleConn(_ context.Context, _ stats.ConnStats) { + // Not interested. +} diff --git a/middleware/grpc_stats_test.go b/middleware/grpc_stats_test.go new file mode 100644 index 000000000..8c21f7713 --- /dev/null +++ b/middleware/grpc_stats_test.go @@ -0,0 +1,274 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/grpc_stats_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "bytes" + "context" + "crypto/rand" + "net" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/health" + "google.golang.org/grpc/health/grpc_health_v1" + + "github.com/grafana/dskit/middleware/middleware_test" +) + +func TestGrpcStats(t *testing.T) { + reg := prometheus.NewRegistry() + + received := promauto.With(reg).NewHistogramVec(prometheus.HistogramOpts{ + Name: "received_payload_bytes", + Help: "Size of received gRPC messages", + Buckets: BodySizeBuckets, + }, []string{"method", "route"}) + + sent := promauto.With(reg).NewHistogramVec(prometheus.HistogramOpts{ + Name: "sent_payload_bytes", + Help: "Size of sent gRPC", + Buckets: BodySizeBuckets, + }, []string{"method", "route"}) + + inflightRequests := promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ + Name: "inflight_requests", + Help: "Current number of inflight requests.", + }, []string{"method", "route"}) + + stats := NewStatsHandler(received, sent, inflightRequests) + + serv := grpc.NewServer(grpc.StatsHandler(stats), grpc.MaxRecvMsgSize(10e6)) + defer serv.GracefulStop() + + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + go func() { + require.NoError(t, serv.Serve(listener)) + }() + + grpc_health_v1.RegisterHealthServer(serv, health.NewServer()) + + closed := false + conn, err := grpc.Dial(listener.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer func() { + if !closed { + require.NoError(t, conn.Close()) + } + }() + + hc := grpc_health_v1.NewHealthClient(conn) + + // First request (empty). + resp, err := hc.Check(context.Background(), &grpc_health_v1.HealthCheckRequest{}) + require.NoError(t, err) + require.Equal(t, grpc_health_v1.HealthCheckResponse_SERVING, resp.Status) + + // Second request, with large service name. This returns error, which doesn't count as "payload". + _, err = hc.Check(context.Background(), &grpc_health_v1.HealthCheckRequest{ + Service: generateString(8 * 1024 * 1024), + }) + require.EqualError(t, err, "rpc error: code = NotFound desc = unknown service") + + err = testutil.GatherAndCompare(reg, bytes.NewBufferString(` + # HELP received_payload_bytes Size of received gRPC messages + # TYPE received_payload_bytes histogram + received_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="1.048576e+06"} 1 + received_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="2.62144e+06"} 1 + received_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="5.24288e+06"} 1 + received_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="1.048576e+07"} 2 + received_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="2.62144e+07"} 2 + received_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="5.24288e+07"} 2 + received_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="1.048576e+08"} 2 + received_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="2.62144e+08"} 2 + received_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="+Inf"} 2 + received_payload_bytes_sum{method="gRPC", route="/grpc.health.v1.Health/Check"} 8.388623e+06 + received_payload_bytes_count{method="gRPC", route="/grpc.health.v1.Health/Check"} 2 + + # HELP sent_payload_bytes Size of sent gRPC + # TYPE sent_payload_bytes histogram + sent_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="1.048576e+06"} 1 + sent_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="2.62144e+06"} 1 + sent_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="5.24288e+06"} 1 + sent_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="1.048576e+07"} 1 + sent_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="2.62144e+07"} 1 + sent_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="5.24288e+07"} 1 + sent_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="1.048576e+08"} 1 + sent_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="2.62144e+08"} 1 + sent_payload_bytes_bucket{method="gRPC", route="/grpc.health.v1.Health/Check",le="+Inf"} 1 + sent_payload_bytes_sum{method="gRPC", route="/grpc.health.v1.Health/Check"} 7 + sent_payload_bytes_count{method="gRPC", route="/grpc.health.v1.Health/Check"} 1 + `), "received_payload_bytes", "sent_payload_bytes") + require.NoError(t, err) + + closed = true + require.NoError(t, conn.Close()) +} + +func TestGrpcStatsStreaming(t *testing.T) { + reg := prometheus.NewRegistry() + + received := promauto.With(reg).NewHistogramVec(prometheus.HistogramOpts{ + Name: "received_payload_bytes", + Help: "Size of received gRPC messages", + Buckets: BodySizeBuckets, + }, []string{"method", "route"}) + + sent := promauto.With(reg).NewHistogramVec(prometheus.HistogramOpts{ + Name: "sent_payload_bytes", + Help: "Size of sent gRPC", + Buckets: BodySizeBuckets, + }, []string{"method", "route"}) + + inflightRequests := promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ + Name: "inflight_requests", + Help: "Current number of inflight requests.", + }, []string{"method", "route"}) + + stats := NewStatsHandler(received, sent, inflightRequests) + + serv := grpc.NewServer(grpc.StatsHandler(stats), grpc.MaxSendMsgSize(10e6), grpc.MaxRecvMsgSize(10e6)) + defer serv.GracefulStop() + + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + go func() { + require.NoError(t, serv.Serve(listener)) + }() + + middleware_test.RegisterEchoServerServer(serv, &halfEcho{log: t.Log}) + + conn, err := grpc.Dial(listener.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(10e6), grpc.MaxCallSendMsgSize(10e6))) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close()) + }() + + fc := middleware_test.NewEchoServerClient(conn) + + s, err := fc.Process(context.Background()) + require.NoError(t, err) + + for ix := 0; ix < 5; ix++ { + msg := &middleware_test.Msg{ + Body: []byte(generateString((ix + 1) * 1024 * 1024)), + } + + t.Log("Client Sending", msg.Size()) + err = s.Send(msg) + require.NoError(t, err) + + _, err := s.Recv() + require.NoError(t, err) + + err = testutil.GatherAndCompare(reg, bytes.NewBufferString(` + # HELP inflight_requests Current number of inflight requests. + # TYPE inflight_requests gauge + inflight_requests{method="gRPC", route="/middleware.EchoServer/Process"} 1 + `), "inflight_requests") + require.NoError(t, err) + } + require.NoError(t, s.CloseSend()) + + // Wait for inflight_requests to go to 0. + timeout := 1 * time.Second + sleep := timeout / 10 + + for endTime := time.Now().Add(timeout); time.Now().Before(endTime); { + err = testutil.GatherAndCompare(reg, bytes.NewBufferString(` + # HELP inflight_requests Current number of inflight requests. + # TYPE inflight_requests gauge + inflight_requests{method="gRPC", route="/middleware.EchoServer/Process"} 0 + `), "inflight_requests") + if err == nil { + break + } + time.Sleep(sleep) + } + require.NoError(t, err) + + err = testutil.GatherAndCompare(reg, bytes.NewBufferString(` + # HELP received_payload_bytes Size of received gRPC messages + # TYPE received_payload_bytes histogram + received_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="1.048576e+06"} 0 + received_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="2.62144e+06"} 2 + received_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="5.24288e+06"} 4 + received_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="1.048576e+07"} 5 + received_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="2.62144e+07"} 5 + received_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="5.24288e+07"} 5 + received_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="1.048576e+08"} 5 + received_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="2.62144e+08"} 5 + received_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="+Inf"} 5 + received_payload_bytes_sum{method="gRPC",route="/middleware.EchoServer/Process"} 1.5728689e+07 + received_payload_bytes_count{method="gRPC",route="/middleware.EchoServer/Process"} 5 + + # HELP sent_payload_bytes Size of sent gRPC + # TYPE sent_payload_bytes histogram + sent_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="1.048576e+06"} 1 + sent_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="2.62144e+06"} 4 + sent_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="5.24288e+06"} 5 + sent_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="1.048576e+07"} 5 + sent_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="2.62144e+07"} 5 + sent_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="5.24288e+07"} 5 + sent_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="1.048576e+08"} 5 + sent_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="2.62144e+08"} 5 + sent_payload_bytes_bucket{method="gRPC",route="/middleware.EchoServer/Process",le="+Inf"} 5 + sent_payload_bytes_sum{method="gRPC",route="/middleware.EchoServer/Process"} 7.864367e+06 + sent_payload_bytes_count{method="gRPC",route="/middleware.EchoServer/Process"} 5 + `), "received_payload_bytes", "sent_payload_bytes") + + require.NoError(t, err) +} + +type halfEcho struct { + log func(args ...interface{}) +} + +func (f halfEcho) Process(server middleware_test.EchoServer_ProcessServer) error { + for { + msg, err := server.Recv() + if err != nil { + return err + } + + // Half the body + msg.Body = msg.Body[:len(msg.Body)/2] + + f.log("Server Sending", msg.Size()) + err = server.Send(msg) + if err != nil { + return err + } + } +} + +func generateString(size int) string { + // Use random bytes, to avoid compression. + buf := make([]byte, size) + _, err := rand.Read(buf) + if err != nil { + // Should not happen. + panic(err) + } + + // To avoid invalid UTF-8 sequences (which protobuf complains about), we cleanup the data a bit. + for ix, b := range buf { + if b < ' ' { + b += ' ' + } + b = b & 0x7f + buf[ix] = b + } + return string(buf) +} diff --git a/middleware/header_adder.go b/middleware/header_adder.go new file mode 100644 index 000000000..ffd5cc8db --- /dev/null +++ b/middleware/header_adder.go @@ -0,0 +1,29 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/header_adder.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "net/http" +) + +// HeaderAdder adds headers to responses +type HeaderAdder struct { + http.Header +} + +// Wrap implements Middleware +func (h HeaderAdder) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Do it in pre-order since headers need to be added before + // writing the response body + dst := w.Header() + for k, vv := range h.Header { + for _, v := range vv { + dst.Add(k, v) + } + } + next.ServeHTTP(w, r) + }) +} diff --git a/middleware/http_auth.go b/middleware/http_auth.go new file mode 100644 index 000000000..2b576a929 --- /dev/null +++ b/middleware/http_auth.go @@ -0,0 +1,23 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/http_auth.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "net/http" + + "github.com/grafana/dskit/user" +) + +// AuthenticateUser propagates the user ID from HTTP headers back to the request's context. +var AuthenticateUser = Func(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, ctx, err := user.ExtractOrgIDFromHTTPRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r.WithContext(ctx)) + }) +}) diff --git a/middleware/http_tracing.go b/middleware/http_tracing.go new file mode 100644 index 000000000..e36bf436d --- /dev/null +++ b/middleware/http_tracing.go @@ -0,0 +1,52 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/http_tracing.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "fmt" + "net/http" + + "github.com/opentracing-contrib/go-stdlib/nethttp" + "github.com/opentracing/opentracing-go" +) + +// Dummy dependency to enforce that we have a nethttp version newer +// than the one which implements Websockets. (No semver on nethttp) +var _ = nethttp.MWURLTagFunc + +// Tracer is a middleware which traces incoming requests. +type Tracer struct { + RouteMatcher RouteMatcher + SourceIPs *SourceIPExtractor +} + +// Wrap implements Interface +func (t Tracer) Wrap(next http.Handler) http.Handler { + options := []nethttp.MWOption{ + nethttp.OperationNameFunc(func(r *http.Request) string { + op := getRouteName(t.RouteMatcher, r) + if op == "" { + return "HTTP " + r.Method + } + + return fmt.Sprintf("HTTP %s - %s", r.Method, op) + }), + nethttp.MWSpanObserver(func(sp opentracing.Span, r *http.Request) { + // add a tag with the client's user agent to the span + userAgent := r.Header.Get("User-Agent") + if userAgent != "" { + sp.SetTag("http.user_agent", userAgent) + } + + // add a tag with the client's sourceIPs to the span, if a + // SourceIPExtractor is given. + if t.SourceIPs != nil { + sp.SetTag("sourceIPs", t.SourceIPs.Get(r)) + } + }), + } + + return nethttp.Middleware(opentracing.GlobalTracer(), next, options...) +} diff --git a/middleware/instrument.go b/middleware/instrument.go new file mode 100644 index 000000000..e5ae9c53c --- /dev/null +++ b/middleware/instrument.go @@ -0,0 +1,164 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/instrument.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "io" + "net/http" + "regexp" + "strconv" + "strings" + + "github.com/felixge/httpsnoop" + "github.com/gorilla/mux" + "github.com/prometheus/client_golang/prometheus" + + "github.com/grafana/dskit/instrument" +) + +const mb = 1024 * 1024 + +// BodySizeBuckets defines buckets for request/response body sizes. +var BodySizeBuckets = []float64{1 * mb, 2.5 * mb, 5 * mb, 10 * mb, 25 * mb, 50 * mb, 100 * mb, 250 * mb} + +// RouteMatcher matches routes +type RouteMatcher interface { + Match(*http.Request, *mux.RouteMatch) bool +} + +// Instrument is a Middleware which records timings for every HTTP request +type Instrument struct { + RouteMatcher RouteMatcher + Duration *prometheus.HistogramVec + RequestBodySize *prometheus.HistogramVec + ResponseBodySize *prometheus.HistogramVec + InflightRequests *prometheus.GaugeVec +} + +// IsWSHandshakeRequest returns true if the given request is a websocket handshake request. +func IsWSHandshakeRequest(req *http.Request) bool { + if strings.ToLower(req.Header.Get("Upgrade")) == "websocket" { + // Connection header values can be of form "foo, bar, ..." + parts := strings.Split(strings.ToLower(req.Header.Get("Connection")), ",") + for _, part := range parts { + if strings.TrimSpace(part) == "upgrade" { + return true + } + } + } + return false +} + +// Wrap implements middleware.Interface +func (i Instrument) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + route := i.getRouteName(r) + inflight := i.InflightRequests.WithLabelValues(r.Method, route) + inflight.Inc() + defer inflight.Dec() + + origBody := r.Body + defer func() { + // No need to leak our Body wrapper beyond the scope of this handler. + r.Body = origBody + }() + + rBody := &reqBody{b: origBody} + r.Body = rBody + + isWS := strconv.FormatBool(IsWSHandshakeRequest(r)) + + respMetrics := httpsnoop.CaptureMetricsFn(w, func(ww http.ResponseWriter) { + next.ServeHTTP(ww, r) + }) + + i.RequestBodySize.WithLabelValues(r.Method, route).Observe(float64(rBody.read)) + i.ResponseBodySize.WithLabelValues(r.Method, route).Observe(float64(respMetrics.Written)) + + instrument.ObserveWithExemplar(r.Context(), i.Duration.WithLabelValues(r.Method, route, strconv.Itoa(respMetrics.Code), isWS), respMetrics.Duration.Seconds()) + }) +} + +// Return a name identifier for ths request. There are three options: +// 1. The request matches a gorilla mux route, with a name. Use that. +// 2. The request matches an unamed gorilla mux router. Munge the path +// template such that templates like '/api/{org}/foo' come out as +// 'api_org_foo'. +// 3. The request doesn't match a mux route. Return "other" +// +// We do all this as we do not wish to emit high cardinality labels to +// prometheus. +func (i Instrument) getRouteName(r *http.Request) string { + route := getRouteName(i.RouteMatcher, r) + if route == "" { + route = "other" + } + + return route +} + +func getRouteName(routeMatcher RouteMatcher, r *http.Request) string { + var routeMatch mux.RouteMatch + if routeMatcher == nil || !routeMatcher.Match(r, &routeMatch) { + return "" + } + + if routeMatch.MatchErr == mux.ErrNotFound { + return "notfound" + } + + if routeMatch.Route == nil { + return "" + } + + if name := routeMatch.Route.GetName(); name != "" { + return name + } + + tmpl, err := routeMatch.Route.GetPathTemplate() + if err == nil { + return MakeLabelValue(tmpl) + } + + return "" +} + +var invalidChars = regexp.MustCompile(`[^a-zA-Z0-9]+`) + +// MakeLabelValue converts a Gorilla mux path to a string suitable for use in +// a Prometheus label value. +func MakeLabelValue(path string) string { + // Convert non-alnums to underscores. + result := invalidChars.ReplaceAllString(path, "_") + + // Trim leading and trailing underscores. + result = strings.Trim(result, "_") + + // Make it all lowercase + result = strings.ToLower(result) + + // Special case. + if result == "" { + result = "root" + } + return result +} + +type reqBody struct { + b io.ReadCloser + read int64 +} + +func (w *reqBody) Read(p []byte) (int, error) { + n, err := w.b.Read(p) + if n > 0 { + w.read += int64(n) + } + return n, err +} + +func (w *reqBody) Close() error { + return w.b.Close() +} diff --git a/middleware/instrument_test.go b/middleware/instrument_test.go new file mode 100644 index 000000000..23ab4e5b0 --- /dev/null +++ b/middleware/instrument_test.go @@ -0,0 +1,34 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/instrument_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware_test + +import ( + "testing" + + "github.com/grafana/dskit/middleware" +) + +func TestMakeLabelValue(t *testing.T) { + for input, want := range map[string]string{ + "/": "root", // special case + "//": "root", // unintended consequence of special case + "a": "a", + "/foo": "foo", + "foo/": "foo", + "/foo/": "foo", + "/foo/bar": "foo_bar", + "foo/bar/": "foo_bar", + "/foo/bar/": "foo_bar", + "/foo/{orgName}/Bar": "foo_orgname_bar", + "/foo/{org_name}/Bar": "foo_org_name_bar", + "/foo/{org__name}/Bar": "foo_org_name_bar", + "/foo/{org___name}/_Bar": "foo_org_name_bar", + "/foo.bar/baz.qux/": "foo_bar_baz_qux", + } { + if have := middleware.MakeLabelValue(input); want != have { + t.Errorf("%q: want %q, have %q", input, want, have) + } + } +} diff --git a/middleware/logging.go b/middleware/logging.go new file mode 100644 index 000000000..ca13f330e --- /dev/null +++ b/middleware/logging.go @@ -0,0 +1,150 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/logging.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "bytes" + "context" + "errors" + "net/http" + "time" + + "github.com/grafana/dskit/log" + "github.com/grafana/dskit/tracing" + "github.com/grafana/dskit/user" +) + +// Log middleware logs http requests +type Log struct { + Log log.Interface + DisableRequestSuccessLog bool + LogRequestHeaders bool // LogRequestHeaders true -> dump http headers at debug log level + LogRequestAtInfoLevel bool // LogRequestAtInfoLevel true -> log requests at info log level + SourceIPs *SourceIPExtractor + HTTPHeadersToExclude map[string]bool +} + +var defaultExcludedHeaders = map[string]bool{ + "Cookie": true, + "X-Csrf-Token": true, + "Authorization": true, +} + +func NewLogMiddleware(log log.Interface, logRequestHeaders bool, logRequestAtInfoLevel bool, sourceIPs *SourceIPExtractor, headersList []string) Log { + httpHeadersToExclude := map[string]bool{} + for header := range defaultExcludedHeaders { + httpHeadersToExclude[header] = true + } + for _, header := range headersList { + httpHeadersToExclude[header] = true + } + + return Log{ + Log: log, + LogRequestHeaders: logRequestHeaders, + LogRequestAtInfoLevel: logRequestAtInfoLevel, + SourceIPs: sourceIPs, + HTTPHeadersToExclude: httpHeadersToExclude, + } +} + +// logWithRequest information from the request and context as fields. +func (l Log) logWithRequest(r *http.Request) log.Interface { + localLog := l.Log + traceID, ok := tracing.ExtractTraceID(r.Context()) + if ok { + localLog = localLog.WithField("traceID", traceID) + } + + if l.SourceIPs != nil { + ips := l.SourceIPs.Get(r) + if ips != "" { + localLog = localLog.WithField("sourceIPs", ips) + } + } + + return user.LogWith(r.Context(), localLog) +} + +// Wrap implements Middleware +func (l Log) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + begin := time.Now() + uri := r.RequestURI // capture the URI before running next, as it may get rewritten + requestLog := l.logWithRequest(r) + // Log headers before running 'next' in case other interceptors change the data. + headers, err := dumpRequest(r, l.HTTPHeadersToExclude) + if err != nil { + headers = nil + requestLog.Errorf("Could not dump request headers: %v", err) + } + var buf bytes.Buffer + wrapped := newBadResponseLoggingWriter(w, &buf) + next.ServeHTTP(wrapped, r) + + statusCode, writeErr := wrapped.getStatusCode(), wrapped.getWriteError() + + if writeErr != nil { + if errors.Is(writeErr, context.Canceled) { + if l.LogRequestAtInfoLevel { + requestLog.Infof("%s %s %s, request cancelled: %s ws: %v; %s", r.Method, uri, time.Since(begin), writeErr, IsWSHandshakeRequest(r), headers) + } else { + requestLog.Debugf("%s %s %s, request cancelled: %s ws: %v; %s", r.Method, uri, time.Since(begin), writeErr, IsWSHandshakeRequest(r), headers) + } + } else { + requestLog.Warnf("%s %s %s, error: %s ws: %v; %s", r.Method, uri, time.Since(begin), writeErr, IsWSHandshakeRequest(r), headers) + } + + return + } + + switch { + // success and shouldn't log successful requests. + case statusCode >= 200 && statusCode < 300 && l.DisableRequestSuccessLog: + return + + case 100 <= statusCode && statusCode < 500 || statusCode == http.StatusBadGateway || statusCode == http.StatusServiceUnavailable: + if l.LogRequestAtInfoLevel { + requestLog.Infof("%s %s (%d) %s", r.Method, uri, statusCode, time.Since(begin)) + + if l.LogRequestHeaders && headers != nil { + requestLog.Infof("ws: %v; %s", IsWSHandshakeRequest(r), string(headers)) + } + return + } + + requestLog.Debugf("%s %s (%d) %s", r.Method, uri, statusCode, time.Since(begin)) + if l.LogRequestHeaders && headers != nil { + requestLog.Debugf("ws: %v; %s", IsWSHandshakeRequest(r), string(headers)) + } + default: + requestLog.Warnf("%s %s (%d) %s Response: %q ws: %v; %s", + r.Method, uri, statusCode, time.Since(begin), buf.Bytes(), IsWSHandshakeRequest(r), headers) + } + }) +} + +// Logging middleware logs each HTTP request method, path, response code and +// duration for all HTTP requests. +var Logging = Log{ + Log: log.Global(), +} + +func dumpRequest(req *http.Request, httpHeadersToExclude map[string]bool) ([]byte, error) { + var b bytes.Buffer + + // In case users initialize the Log middleware using the exported struct, skip the default headers anyway + if len(httpHeadersToExclude) == 0 { + httpHeadersToExclude = defaultExcludedHeaders + } + // Exclude some headers for security, or just that we don't need them when debugging + err := req.Header.WriteSubset(&b, httpHeadersToExclude) + if err != nil { + return nil, err + } + + ret := bytes.Replace(b.Bytes(), []byte("\r\n"), []byte("; "), -1) + return ret, nil +} diff --git a/middleware/logging_test.go b/middleware/logging_test.go new file mode 100644 index 000000000..5a7ae602c --- /dev/null +++ b/middleware/logging_test.go @@ -0,0 +1,229 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/logging_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "github.com/grafana/dskit/log" +) + +func TestBadWriteLogging(t *testing.T) { + for _, tc := range []struct { + err error + logContains []string + }{{ + err: context.Canceled, + logContains: []string{"debug", "request cancelled: context canceled"}, + }, { + err: errors.New("yolo"), + logContains: []string{"warning", "error: yolo"}, + }, { + err: nil, + logContains: []string{"debug", "GET http://example.com/foo (200)"}, + }} { + buf := bytes.NewBuffer(nil) + logrusLogger := logrus.New() + logrusLogger.Out = buf + logrusLogger.Level = logrus.DebugLevel + + loggingMiddleware := Log{ + Log: log.Logrus(logrusLogger), + } + handler := func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "Hello World!") + } + loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler)) + + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + recorder := httptest.NewRecorder() + + w := errorWriter{ + err: tc.err, + w: recorder, + } + loggingHandler.ServeHTTP(w, req) + + for _, content := range tc.logContains { + require.True(t, bytes.Contains(buf.Bytes(), []byte(content))) + } + } +} + +func TestDisabledSuccessfulRequestsLogging(t *testing.T) { + for _, tc := range []struct { + err error + disableLog bool + logContains string + }{ + { + err: nil, + disableLog: false, + }, { + err: nil, + disableLog: true, + logContains: "", + }, + } { + buf := bytes.NewBuffer(nil) + logrusLogger := logrus.New() + logrusLogger.Out = buf + logrusLogger.Level = logrus.DebugLevel + + loggingMiddleware := Log{ + Log: log.Logrus(logrusLogger), + DisableRequestSuccessLog: tc.disableLog, + } + + handler := func(w http.ResponseWriter, r *http.Request) { + _, err := io.WriteString(w, "Hello World!") + require.NoError(t, err) //nolint:errcheck + } + loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler)) + + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + recorder := httptest.NewRecorder() + + w := errorWriter{ + err: tc.err, + w: recorder, + } + loggingHandler.ServeHTTP(w, req) + content := buf.String() + + if !tc.disableLog { + require.Contains(t, content, "GET http://example.com/foo (200)") + } else { + require.NotContains(t, content, "(200)") + require.Empty(t, content) + } + } +} + +func TestLoggingRequestsAtInfoLevel(t *testing.T) { + for _, tc := range []struct { + err error + logContains []string + }{{ + err: context.Canceled, + logContains: []string{"info", "request cancelled: context canceled"}, + }, { + err: nil, + logContains: []string{"info", "GET http://example.com/foo (200)"}, + }} { + buf := bytes.NewBuffer(nil) + logrusLogger := logrus.New() + logrusLogger.Out = buf + logrusLogger.Level = logrus.DebugLevel + + loggingMiddleware := Log{ + Log: log.Logrus(logrusLogger), + LogRequestAtInfoLevel: true, + } + handler := func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "Hello World!") + } + loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler)) + + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + recorder := httptest.NewRecorder() + + w := errorWriter{ + err: tc.err, + w: recorder, + } + loggingHandler.ServeHTTP(w, req) + + for _, content := range tc.logContains { + require.True(t, bytes.Contains(buf.Bytes(), []byte(content))) + } + } +} + +func TestLoggingRequestWithExcludedHeaders(t *testing.T) { + defaultHeaders := []string{"Authorization", "Cookie", "X-Csrf-Token"} + for _, tc := range []struct { + name string + setHeaderList []string + excludeHeaderList []string + mustNotContain []string + }{ + { + name: "Default excluded headers are excluded", + setHeaderList: defaultHeaders, + mustNotContain: defaultHeaders, + }, + { + name: "Extra configured header is also excluded", + setHeaderList: append(defaultHeaders, "X-Secret-Header"), + excludeHeaderList: []string{"X-Secret-Header"}, + mustNotContain: append(defaultHeaders, "X-Secret-Header"), + }, + { + name: "Multiple extra configured headers are also excluded", + setHeaderList: append(defaultHeaders, "X-Secret-Header", "X-Secret-Header-2"), + excludeHeaderList: []string{"X-Secret-Header", "X-Secret-Header-2"}, + mustNotContain: append(defaultHeaders, "X-Secret-Header", "X-Secret-Header-2"), + }, + } { + t.Run(tc.name, func(t *testing.T) { + buf := bytes.NewBuffer(nil) + logrusLogger := logrus.New() + logrusLogger.Out = buf + logrusLogger.Level = logrus.DebugLevel + + loggingMiddleware := NewLogMiddleware(log.Logrus(logrusLogger), true, false, nil, tc.excludeHeaderList) + + handler := func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "Hello world!") + } + loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler)) + + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + for _, header := range tc.setHeaderList { + req.Header.Set(header, header) + } + + recorder := httptest.NewRecorder() + loggingHandler.ServeHTTP(recorder, req) + + output := buf.String() + for _, header := range tc.mustNotContain { + require.NotContains(t, output, header) + } + }) + } +} + +type errorWriter struct { + err error + + w http.ResponseWriter +} + +func (e errorWriter) Header() http.Header { + return e.w.Header() +} + +func (e errorWriter) WriteHeader(statusCode int) { + e.w.WriteHeader(statusCode) +} + +func (e errorWriter) Write(b []byte) (int, error) { + if e.err != nil { + return 0, e.err + } + + return e.w.Write(b) +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 000000000..79720b333 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,37 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/middleware.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "net/http" +) + +// Interface is the shared contract for all middlesware, and allows middlesware +// to wrap handlers. +type Interface interface { + Wrap(http.Handler) http.Handler +} + +// Func is to Interface as http.HandlerFunc is to http.Handler +type Func func(http.Handler) http.Handler + +// Wrap implements Interface +func (m Func) Wrap(next http.Handler) http.Handler { + return m(next) +} + +// Identity is an Interface which doesn't do anything. +var Identity Interface = Func(func(h http.Handler) http.Handler { return h }) + +// Merge produces a middleware that applies multiple middlesware in turn; +// ie Merge(f,g,h).Wrap(handler) == f.Wrap(g.Wrap(h.Wrap(handler))) +func Merge(middlesware ...Interface) Interface { + return Func(func(next http.Handler) http.Handler { + for i := len(middlesware) - 1; i >= 0; i-- { + next = middlesware[i].Wrap(next) + } + return next + }) +} diff --git a/middleware/middleware_test/echo_server.pb.go b/middleware/middleware_test/echo_server.pb.go new file mode 100644 index 000000000..ff6509352 --- /dev/null +++ b/middleware/middleware_test/echo_server.pb.go @@ -0,0 +1,526 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: echo_server.proto + +package middleware_test + +import ( + bytes "bytes" + context "context" + fmt "fmt" + proto "github.com/gogo/protobuf/proto" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + io "io" + math "math" + math_bits "math/bits" + reflect "reflect" + strings "strings" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type Msg struct { + Body []byte `protobuf:"bytes,1,opt,name=body,proto3" json:"body,omitempty"` +} + +func (m *Msg) Reset() { *m = Msg{} } +func (*Msg) ProtoMessage() {} +func (*Msg) Descriptor() ([]byte, []int) { + return fileDescriptor_f24ecbb5972b7a26, []int{0} +} +func (m *Msg) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Msg) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Msg.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Msg) XXX_Merge(src proto.Message) { + xxx_messageInfo_Msg.Merge(m, src) +} +func (m *Msg) XXX_Size() int { + return m.Size() +} +func (m *Msg) XXX_DiscardUnknown() { + xxx_messageInfo_Msg.DiscardUnknown(m) +} + +var xxx_messageInfo_Msg proto.InternalMessageInfo + +func (m *Msg) GetBody() []byte { + if m != nil { + return m.Body + } + return nil +} + +func init() { + proto.RegisterType((*Msg)(nil), "middleware.Msg") +} + +func init() { proto.RegisterFile("echo_server.proto", fileDescriptor_f24ecbb5972b7a26) } + +var fileDescriptor_f24ecbb5972b7a26 = []byte{ + // 190 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x4c, 0x4d, 0xce, 0xc8, + 0x8f, 0x2f, 0x4e, 0x2d, 0x2a, 0x4b, 0x2d, 0xd2, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0xca, + 0xcd, 0x4c, 0x49, 0xc9, 0x49, 0x2d, 0x4f, 0x2c, 0x4a, 0x55, 0x92, 0xe4, 0x62, 0xf6, 0x2d, 0x4e, + 0x17, 0x12, 0xe2, 0x62, 0x49, 0xca, 0x4f, 0xa9, 0x94, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x09, 0x02, + 0xb3, 0x8d, 0xec, 0xb9, 0xb8, 0x5c, 0x93, 0x33, 0xf2, 0x83, 0xc1, 0x5a, 0x85, 0x0c, 0xb9, 0xd8, + 0x03, 0x8a, 0xf2, 0x93, 0x53, 0x8b, 0x8b, 0x85, 0xf8, 0xf5, 0x10, 0x06, 0xe8, 0xf9, 0x16, 0xa7, + 0x4b, 0xa1, 0x0b, 0x28, 0x31, 0x68, 0x30, 0x1a, 0x30, 0x3a, 0xb9, 0x5e, 0x78, 0x28, 0xc7, 0x70, + 0xe3, 0xa1, 0x1c, 0xc3, 0x87, 0x87, 0x72, 0x8c, 0x0d, 0x8f, 0xe4, 0x18, 0x57, 0x3c, 0x92, 0x63, + 0x3c, 0xf1, 0x48, 0x8e, 0xf1, 0xc2, 0x23, 0x39, 0xc6, 0x07, 0x8f, 0xe4, 0x18, 0x5f, 0x3c, 0x92, + 0x63, 0xf8, 0xf0, 0x48, 0x8e, 0x71, 0xc2, 0x63, 0x39, 0x86, 0x0b, 0x8f, 0xe5, 0x18, 0x6e, 0x3c, + 0x96, 0x63, 0x88, 0xe2, 0x47, 0x98, 0x15, 0x5f, 0x92, 0x5a, 0x5c, 0x92, 0xc4, 0x06, 0x76, 0xb5, + 0x31, 0x20, 0x00, 0x00, 0xff, 0xff, 0xce, 0xf1, 0x2d, 0xa4, 0xca, 0x00, 0x00, 0x00, +} + +func (this *Msg) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + that1, ok := that.(*Msg) + if !ok { + that2, ok := that.(Msg) + if ok { + that1 = &that2 + } else { + return false + } + } + if that1 == nil { + return this == nil + } else if this == nil { + return false + } + if !bytes.Equal(this.Body, that1.Body) { + return false + } + return true +} +func (this *Msg) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 5) + s = append(s, "&middleware_test.Msg{") + s = append(s, "Body: "+fmt.Sprintf("%#v", this.Body)+",\n") + s = append(s, "}") + return strings.Join(s, "") +} +func valueToGoStringEchoServer(v interface{}, typ string) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("func(v %v) *%v { return &v } ( %#v )", typ, typ, pv) +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// EchoServerClient is the client API for EchoServer service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type EchoServerClient interface { + Process(ctx context.Context, opts ...grpc.CallOption) (EchoServer_ProcessClient, error) +} + +type echoServerClient struct { + cc *grpc.ClientConn +} + +func NewEchoServerClient(cc *grpc.ClientConn) EchoServerClient { + return &echoServerClient{cc} +} + +func (c *echoServerClient) Process(ctx context.Context, opts ...grpc.CallOption) (EchoServer_ProcessClient, error) { + stream, err := c.cc.NewStream(ctx, &_EchoServer_serviceDesc.Streams[0], "/middleware.EchoServer/Process", opts...) + if err != nil { + return nil, err + } + x := &echoServerProcessClient{stream} + return x, nil +} + +type EchoServer_ProcessClient interface { + Send(*Msg) error + Recv() (*Msg, error) + grpc.ClientStream +} + +type echoServerProcessClient struct { + grpc.ClientStream +} + +func (x *echoServerProcessClient) Send(m *Msg) error { + return x.ClientStream.SendMsg(m) +} + +func (x *echoServerProcessClient) Recv() (*Msg, error) { + m := new(Msg) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// EchoServerServer is the server API for EchoServer service. +type EchoServerServer interface { + Process(EchoServer_ProcessServer) error +} + +// UnimplementedEchoServerServer can be embedded to have forward compatible implementations. +type UnimplementedEchoServerServer struct { +} + +func (*UnimplementedEchoServerServer) Process(srv EchoServer_ProcessServer) error { + return status.Errorf(codes.Unimplemented, "method Process not implemented") +} + +func RegisterEchoServerServer(s *grpc.Server, srv EchoServerServer) { + s.RegisterService(&_EchoServer_serviceDesc, srv) +} + +func _EchoServer_Process_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(EchoServerServer).Process(&echoServerProcessServer{stream}) +} + +type EchoServer_ProcessServer interface { + Send(*Msg) error + Recv() (*Msg, error) + grpc.ServerStream +} + +type echoServerProcessServer struct { + grpc.ServerStream +} + +func (x *echoServerProcessServer) Send(m *Msg) error { + return x.ServerStream.SendMsg(m) +} + +func (x *echoServerProcessServer) Recv() (*Msg, error) { + m := new(Msg) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _EchoServer_serviceDesc = grpc.ServiceDesc{ + ServiceName: "middleware.EchoServer", + HandlerType: (*EchoServerServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "Process", + Handler: _EchoServer_Process_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "echo_server.proto", +} + +func (m *Msg) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Msg) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Msg) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.Body) > 0 { + i -= len(m.Body) + copy(dAtA[i:], m.Body) + i = encodeVarintEchoServer(dAtA, i, uint64(len(m.Body))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func encodeVarintEchoServer(dAtA []byte, offset int, v uint64) int { + offset -= sovEchoServer(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *Msg) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Body) + if l > 0 { + n += 1 + l + sovEchoServer(uint64(l)) + } + return n +} + +func sovEchoServer(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozEchoServer(x uint64) (n int) { + return sovEchoServer(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (this *Msg) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&Msg{`, + `Body:` + fmt.Sprintf("%v", this.Body) + `,`, + `}`, + }, "") + return s +} +func valueToStringEchoServer(v interface{}) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("*%v", pv) +} +func (m *Msg) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowEchoServer + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Msg: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Msg: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Body", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowEchoServer + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthEchoServer + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthEchoServer + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Body = append(m.Body[:0], dAtA[iNdEx:postIndex]...) + if m.Body == nil { + m.Body = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipEchoServer(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthEchoServer + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthEchoServer + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipEchoServer(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowEchoServer + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowEchoServer + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowEchoServer + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthEchoServer + } + iNdEx += length + if iNdEx < 0 { + return 0, ErrInvalidLengthEchoServer + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowEchoServer + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipEchoServer(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + if iNdEx < 0 { + return 0, ErrInvalidLengthEchoServer + } + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthEchoServer = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowEchoServer = fmt.Errorf("proto: integer overflow") +) diff --git a/middleware/middleware_test/echo_server.proto b/middleware/middleware_test/echo_server.proto new file mode 100644 index 000000000..4c0946d97 --- /dev/null +++ b/middleware/middleware_test/echo_server.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package middleware; + +option go_package = "middleware_test"; + +service EchoServer { + rpc Process(stream Msg) returns (stream Msg) {}; +} + +message Msg { + bytes body = 1; +} diff --git a/middleware/path_rewrite.go b/middleware/path_rewrite.go new file mode 100644 index 000000000..c9e917a72 --- /dev/null +++ b/middleware/path_rewrite.go @@ -0,0 +1,56 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/path_rewrite.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "net/http" + "net/url" + "regexp" + + log "github.com/sirupsen/logrus" +) + +// PathRewrite supports regex matching and replace on Request URIs +func PathRewrite(regexp *regexp.Regexp, replacement string) Interface { + return pathRewrite{ + regexp: regexp, + replacement: replacement, + } +} + +type pathRewrite struct { + regexp *regexp.Regexp + replacement string +} + +func (p pathRewrite) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.RequestURI = p.regexp.ReplaceAllString(r.RequestURI, p.replacement) + r.URL.RawPath = p.regexp.ReplaceAllString(r.URL.EscapedPath(), p.replacement) + path, err := url.PathUnescape(r.URL.RawPath) + if err != nil { + log.Errorf("Got invalid url-encoded path %v after applying path rewrite %v: %v", r.URL.RawPath, p, err) + w.WriteHeader(http.StatusInternalServerError) + return + } + r.URL.Path = path + next.ServeHTTP(w, r) + }) +} + +// PathReplace replcase Request.RequestURI with the specified string. +func PathReplace(replacement string) Interface { + return pathReplace(replacement) +} + +type pathReplace string + +func (p pathReplace) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Path = string(p) + r.RequestURI = string(p) + next.ServeHTTP(w, r) + }) +} diff --git a/middleware/response.go b/middleware/response.go new file mode 100644 index 000000000..e2ce1d0a7 --- /dev/null +++ b/middleware/response.go @@ -0,0 +1,127 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/response.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "bufio" + "fmt" + "io" + "net" + "net/http" +) + +const ( + maxResponseBodyInLogs = 4096 // At most 4k bytes from response bodies in our logs. +) + +type badResponseLoggingWriter interface { + http.ResponseWriter + getStatusCode() int + getWriteError() error +} + +// nonFlushingBadResponseLoggingWriter writes the body of "bad" responses (i.e. 5xx +// responses) to a buffer. +type nonFlushingBadResponseLoggingWriter struct { + rw http.ResponseWriter + buffer io.Writer + logBody bool + bodyBytesLeft int + statusCode int + writeError error // The error returned when downstream Write() fails. +} + +// flushingBadResponseLoggingWriter is a badResponseLoggingWriter that +// implements http.Flusher. +type flushingBadResponseLoggingWriter struct { + nonFlushingBadResponseLoggingWriter + f http.Flusher +} + +func newBadResponseLoggingWriter(rw http.ResponseWriter, buffer io.Writer) badResponseLoggingWriter { + b := nonFlushingBadResponseLoggingWriter{ + rw: rw, + buffer: buffer, + logBody: false, + bodyBytesLeft: maxResponseBodyInLogs, + statusCode: http.StatusOK, + } + + if f, ok := rw.(http.Flusher); ok { + return &flushingBadResponseLoggingWriter{b, f} + } + + return &b +} + +// Unwrap method is used by http.ResponseController to get access to original http.ResponseWriter. +func (b *nonFlushingBadResponseLoggingWriter) Unwrap() http.ResponseWriter { + return b.rw +} + +// Header returns the header map that will be sent by WriteHeader. +// Implements ResponseWriter. +func (b *nonFlushingBadResponseLoggingWriter) Header() http.Header { + return b.rw.Header() +} + +// Write writes HTTP response data. +func (b *nonFlushingBadResponseLoggingWriter) Write(data []byte) (int, error) { + if b.statusCode == 0 { + // WriteHeader has (probably) not been called, so we need to call it with StatusOK to fulfill the interface contract. + // https://godoc.org/net/http#ResponseWriter + b.WriteHeader(http.StatusOK) + } + n, err := b.rw.Write(data) + if b.logBody { + b.captureResponseBody(data) + } + if err != nil { + b.writeError = err + } + return n, err +} + +// WriteHeader writes the HTTP response header. +func (b *nonFlushingBadResponseLoggingWriter) WriteHeader(statusCode int) { + b.statusCode = statusCode + if statusCode >= 500 { + b.logBody = true + } + b.rw.WriteHeader(statusCode) +} + +// Hijack hijacks the first response writer that is a Hijacker. +func (b *nonFlushingBadResponseLoggingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := b.rw.(http.Hijacker) + if ok { + return hj.Hijack() + } + return nil, nil, fmt.Errorf("badResponseLoggingWriter: can't cast underlying response writer to Hijacker") +} + +func (b *nonFlushingBadResponseLoggingWriter) getStatusCode() int { + return b.statusCode +} + +func (b *nonFlushingBadResponseLoggingWriter) getWriteError() error { + return b.writeError +} + +func (b *nonFlushingBadResponseLoggingWriter) captureResponseBody(data []byte) { + if len(data) > b.bodyBytesLeft { + _, _ = b.buffer.Write(data[:b.bodyBytesLeft]) + _, _ = io.WriteString(b.buffer, "...") + b.bodyBytesLeft = 0 + b.logBody = false + } else { + _, _ = b.buffer.Write(data) + b.bodyBytesLeft -= len(data) + } +} + +func (b *flushingBadResponseLoggingWriter) Flush() { + b.f.Flush() +} diff --git a/middleware/response_test.go b/middleware/response_test.go new file mode 100644 index 000000000..2846276ad --- /dev/null +++ b/middleware/response_test.go @@ -0,0 +1,97 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/response_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBadResponseLoggingWriter(t *testing.T) { + for _, tc := range []struct { + statusCode int + data string + expected string + }{ + {http.StatusOK, "", ""}, + {http.StatusOK, "some data", ""}, + {http.StatusUnprocessableEntity, "unprocessable", ""}, + {http.StatusInternalServerError, "", ""}, + {http.StatusInternalServerError, "bad juju", "bad juju\n"}, + } { + w := httptest.NewRecorder() + var buf bytes.Buffer + wrapped := newBadResponseLoggingWriter(w, &buf) + switch { + case tc.data == "": + wrapped.WriteHeader(tc.statusCode) + case tc.statusCode < 300 && tc.data != "": + wrapped.WriteHeader(tc.statusCode) + _, err := wrapped.Write([]byte(tc.data)) + require.NoError(t, err) + default: + http.Error(wrapped, tc.data, tc.statusCode) + } + if wrapped.getStatusCode() != tc.statusCode { + t.Errorf("Wrong status code: have %d want %d", wrapped.getStatusCode(), tc.statusCode) + } + data := buf.String() + if data != tc.expected { + t.Errorf("Wrong data: have %q want %q", data, tc.expected) + } + } +} + +// nonFlushingResponseWriter implements http.ResponseWriter but does not implement http.Flusher +type nonFlushingResponseWriter struct{} + +func (rw *nonFlushingResponseWriter) Header() http.Header { + return nil +} + +func (rw *nonFlushingResponseWriter) Write(_ []byte) (int, error) { + return -1, nil +} + +func (rw *nonFlushingResponseWriter) WriteHeader(_ int) { +} + +func TestBadResponseLoggingWriter_WithAndWithoutFlusher(t *testing.T) { + var buf bytes.Buffer + + nf := newBadResponseLoggingWriter(&nonFlushingResponseWriter{}, &buf) + + _, ok := nf.(http.Flusher) + if ok { + t.Errorf("Should not be able to cast nf as an http.Flusher") + } + + rec := httptest.NewRecorder() + f := newBadResponseLoggingWriter(rec, &buf) + + ff, ok := f.(http.Flusher) + if !ok { + t.Errorf("Should be able to cast f as an http.Flusher") + } + + ff.Flush() + if !rec.Flushed { + t.Errorf("Flush should have worked but did not") + } +} + +type responseWriterWithUnwrap interface { + http.ResponseWriter + Unwrap() http.ResponseWriter +} + +// Verify that custom http.ResponseWriter implementations implement Unwrap() method, used by http.ResponseContoller. +var _ responseWriterWithUnwrap = &nonFlushingBadResponseLoggingWriter{} +var _ responseWriterWithUnwrap = &flushingBadResponseLoggingWriter{} +var _ responseWriterWithUnwrap = &errorInterceptor{} diff --git a/middleware/source_ips.go b/middleware/source_ips.go new file mode 100644 index 000000000..7c035ddbf --- /dev/null +++ b/middleware/source_ips.go @@ -0,0 +1,145 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/source_ips.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "fmt" + "net" + "net/http" + "regexp" + "strings" +) + +// Parts copied and changed from gorilla mux proxy_headers.go + +var ( + // De-facto standard header keys. + xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") + xRealIP = http.CanonicalHeaderKey("X-Real-IP") +) + +var ( + // RFC7239 defines a new "Forwarded: " header designed to replace the + // existing use of X-Forwarded-* headers. + // e.g. Forwarded: for=192.0.2.60;proto=https;by=203.0.113.43 + forwarded = http.CanonicalHeaderKey("Forwarded") + // Allows for a sub-match of the first value after 'for=' to the next + // comma, semi-colon or space. The match is case-insensitive. + forRegex = regexp.MustCompile(`(?i)(?:for=)([^(;|,| )]+)`) +) + +// SourceIPExtractor extracts the source IPs from a HTTP request +type SourceIPExtractor struct { + // The header to search for + header string + // A regex that extracts the IP address from the header. + // It should contain at least one capturing group the first of which will be returned. + regex *regexp.Regexp +} + +// NewSourceIPs creates a new SourceIPs +func NewSourceIPs(header, regex string) (*SourceIPExtractor, error) { + if (header == "" && regex != "") || (header != "" && regex == "") { + return nil, fmt.Errorf("either both a header field and a regex have to be given or neither") + } + re, err := regexp.Compile(regex) + if err != nil { + return nil, fmt.Errorf("invalid regex given") + } + + return &SourceIPExtractor{ + header: header, + regex: re, + }, nil +} + +// extractHost returns the Host IP address without any port information +func extractHost(address string) string { + hostIP := net.ParseIP(address) + if hostIP != nil { + return hostIP.String() + } + var err error + hostStr, _, err := net.SplitHostPort(address) + if err != nil { + // Invalid IP address, just return it so it shows up in the logs + return address + } + return hostStr +} + +// Get returns any source addresses we can find in the request, comma-separated +func (sips SourceIPExtractor) Get(req *http.Request) string { + fwd := extractHost(sips.getIP(req)) + if fwd == "" { + if req.RemoteAddr == "" { + return "" + } + return extractHost(req.RemoteAddr) + } + // If RemoteAddr is empty just return the header + if req.RemoteAddr == "" { + return fwd + } + remoteIP := extractHost(req.RemoteAddr) + if fwd == remoteIP { + return remoteIP + } + // If both a header and RemoteAddr are present return them both, stripping off any port info from the RemoteAddr + return fmt.Sprintf("%v, %v", fwd, remoteIP) +} + +// getIP retrieves the IP from the RFC7239 Forwarded headers, +// X-Real-IP and X-Forwarded-For (in that order) or from the +// custom regex. +func (sips SourceIPExtractor) getIP(r *http.Request) string { + var addr string + + // Use the custom regex only if it was setup + if sips.header != "" { + hdr := r.Header.Get(sips.header) + if hdr == "" { + return "" + } + allMatches := sips.regex.FindAllStringSubmatch(hdr, 1) + if len(allMatches) == 0 { + return "" + } + firstMatch := allMatches[0] + // Check there is at least 1 submatch + if len(firstMatch) < 2 { + return "" + } + return firstMatch[1] + } + + if fwd := r.Header.Get(forwarded); fwd != "" { + // match should contain at least two elements if the protocol was + // specified in the Forwarded header. The first element will always be + // the 'for=' capture, which we ignore. In the case of multiple IP + // addresses (for=8.8.8.8, 8.8.4.4,172.16.1.20 is valid) we only + // extract the first, which should be the client IP. + if match := forRegex.FindStringSubmatch(fwd); len(match) > 1 { + // IPv6 addresses in Forwarded headers are quoted-strings. We strip + // these quotes. + addr = strings.Trim(match[1], `"`) + } + } else if fwd := r.Header.Get(xRealIP); fwd != "" { + // X-Real-IP should only contain one IP address (the client making the + // request). + addr = fwd + } else if fwd := strings.ReplaceAll(r.Header.Get(xForwardedFor), " ", ""); fwd != "" { + // Only grab the first (client) address. Note that '192.168.0.1, + // 10.1.1.1' is a valid key for X-Forwarded-For where addresses after + // the first may represent forwarding proxies earlier in the chain. + s := strings.Index(fwd, ",") + if s == -1 { + s = len(fwd) + } + addr = fwd[:s] + } + + return addr +} diff --git a/middleware/source_ips_test.go b/middleware/source_ips_test.go new file mode 100644 index 000000000..bc6f74ff5 --- /dev/null +++ b/middleware/source_ips_test.go @@ -0,0 +1,270 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/middleware/source_ips_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package middleware + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetSourceIPs(t *testing.T) { + tests := []struct { + name string + req *http.Request + want string + }{ + { + name: "no header", + req: &http.Request{RemoteAddr: "192.168.1.100:3454"}, + want: "192.168.1.100", + }, + { + name: "no header and remote has no port", + req: &http.Request{RemoteAddr: "192.168.1.100"}, + want: "192.168.1.100", + }, + { + name: "no header, remote address is invalid", + req: &http.Request{RemoteAddr: "192.168.100"}, + want: "192.168.100", + }, + { + name: "X-Forwarded-For and single forward address", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"172.16.1.1"}, + }, + }, + want: "172.16.1.1, 192.168.1.100", + }, + { + name: "X-Forwarded-For and single forward address which is same as remote", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"192.168.1.100"}, + }, + }, + want: "192.168.1.100", + }, + { + name: "single IPv6 X-Forwarded-For address", + req: &http.Request{ + RemoteAddr: "[2001:db9::1]:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"2001:db8::1"}, + }, + }, + want: "2001:db8::1, 2001:db9::1", + }, + { + name: "single X-Forwarded-For address no RemoteAddr", + req: &http.Request{ + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"172.16.1.1"}, + }, + }, + want: "172.16.1.1", + }, + { + name: "multiple X-Forwarded-For with remote", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"172.16.1.1, 10.10.13.20"}, + }, + }, + want: "172.16.1.1, 192.168.1.100", + }, + { + name: "multiple X-Forwarded-For with remote and no spaces", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"172.16.1.1,10.10.13.20,10.11.16.46"}, + }, + }, + want: "172.16.1.1, 192.168.1.100", + }, + { + name: "multiple X-Forwarded-For with IPv6 remote", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"[2001:db8:cafe::17]:4711, 10.10.13.20"}, + }, + }, + want: "2001:db8:cafe::17, 192.168.1.100", + }, + { + name: "no header, no remote", + req: &http.Request{}, + want: "", + }, + { + name: "X-Real-IP with IPv6 remote with port", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xRealIP): {"[2001:db8:cafe::17]:4711"}, + }, + }, + want: "2001:db8:cafe::17, 192.168.1.100", + }, + { + name: "X-Real-IP with IPv4 remote", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xRealIP): {"192.169.1.200"}, + }, + }, + want: "192.169.1.200, 192.168.1.100", + }, + { + name: "X-Real-IP with IPv4 remote and X-Forwarded-For", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"[2001:db8:cafe::17]:4711, 10.10.13.20"}, + http.CanonicalHeaderKey(xRealIP): {"192.169.1.200"}, + }, + }, + want: "192.169.1.200, 192.168.1.100", + }, + { + name: "Forwarded with IPv4 remote", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(forwarded): {"for=192.169.1.200"}, + }, + }, + want: "192.169.1.200, 192.168.1.100", + }, + { + name: "Forwarded with IPv4 and proto and by fields", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(forwarded): {"for=192.0.2.60;proto=http;by=203.0.113.43"}, + }, + }, + want: "192.0.2.60, 192.168.1.100", + }, + { + name: "Forwarded with IPv6 and IPv4 remote", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(forwarded): {"for=[2001:db8:cafe::17]:4711,for=192.169.1.200"}, + }, + }, + want: "2001:db8:cafe::17, 192.168.1.100", + }, + { + name: "Forwarded with X-Real-IP and X-Forwarded-For", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(xForwardedFor): {"[2001:db8:cafe::17]:4711, 10.10.13.20"}, + http.CanonicalHeaderKey(xRealIP): {"192.169.1.200"}, + http.CanonicalHeaderKey(forwarded): {"for=[2001:db8:cafe::17]:4711,for=192.169.1.200"}, + }, + }, + want: "2001:db8:cafe::17, 192.168.1.100", + }, + { + name: "Forwarded returns hostname", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey(forwarded): {"for=workstation.local"}, + }, + }, + want: "workstation.local, 192.168.1.100", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sourceIPs, err := NewSourceIPs("", "") + require.NoError(t, err) + + if got := sourceIPs.Get(tt.req); got != tt.want { + t.Errorf("GetSource() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetSourceIPsWithCustomRegex(t *testing.T) { + tests := []struct { + name string + req *http.Request + want string + }{ + { + name: "no header", + req: &http.Request{RemoteAddr: "192.168.1.100:3454"}, + want: "192.168.1.100", + }, + { + name: "No matching entry in the header", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey("SomeHeader"): {"not matching"}, + }, + }, + want: "192.168.1.100", + }, + { + name: "one matching entry in the header", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey("SomeHeader"): {"172.16.1.1"}, + }, + }, + want: "172.16.1.1, 192.168.1.100", + }, + { + name: "multiple matching entries in the header, only first used", + req: &http.Request{ + RemoteAddr: "192.168.1.100:3454", + Header: map[string][]string{ + http.CanonicalHeaderKey("SomeHeader"): {"172.16.1.1", "172.16.2.1"}, + }, + }, + want: "172.16.1.1, 192.168.1.100", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sourceIPs, err := NewSourceIPs("SomeHeader", "((?:[0-9]{1,3}\\.){3}[0-9]{1,3})") + require.NoError(t, err) + + if got := sourceIPs.Get(tt.req); got != tt.want { + t.Errorf("GetSource() = %v, want %v", got, tt.want) + } + }) + } +} +func TestInvalid(t *testing.T) { + sourceIPs, err := NewSourceIPs("Header", "") + require.Empty(t, sourceIPs) + require.Error(t, err) + + sourceIPs, err = NewSourceIPs("", "a(.*)b") + require.Empty(t, sourceIPs) + require.Error(t, err) + + sourceIPs, err = NewSourceIPs("Header", "[*") + require.Empty(t, sourceIPs) + require.Error(t, err) +} diff --git a/mtime/mtime.go b/mtime/mtime.go new file mode 100644 index 000000000..2b490f0a8 --- /dev/null +++ b/mtime/mtime.go @@ -0,0 +1,20 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/mtime/mtime.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package mtime + +import "time" + +// Now returns the current time. +var Now = func() time.Time { return time.Now() } + +// NowForce sets the time returned by Now to t. +func NowForce(t time.Time) { + Now = func() time.Time { return t } +} + +// NowReset makes Now returns the current time again. +func NowReset() { + Now = func() time.Time { return time.Now() } +} diff --git a/ring/client/pool.go b/ring/client/pool.go index ee9d06b6b..5e21e69c3 100644 --- a/ring/client/pool.go +++ b/ring/client/pool.go @@ -10,12 +10,12 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/prometheus/client_golang/prometheus" - "github.com/weaveworks/common/user" "google.golang.org/grpc/health/grpc_health_v1" "github.com/grafana/dskit/concurrency" "github.com/grafana/dskit/internal/slices" "github.com/grafana/dskit/services" + "github.com/grafana/dskit/user" ) // PoolClient is the interface that should be implemented by a diff --git a/ring/ring_test.go b/ring/ring_test.go index fc294fe98..45228c3e1 100644 --- a/ring/ring_test.go +++ b/ring/ring_test.go @@ -18,11 +18,11 @@ import ( "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/weaveworks/common/httpgrpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "github.com/grafana/dskit/flagext" + "github.com/grafana/dskit/httpgrpc" dsmath "github.com/grafana/dskit/internal/math" "github.com/grafana/dskit/internal/slices" "github.com/grafana/dskit/kv" diff --git a/server/certs/genCerts.sh b/server/certs/genCerts.sh new file mode 100644 index 000000000..f2f52fc33 --- /dev/null +++ b/server/certs/genCerts.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +# From https://github.com/joe-elliott/cert-exporter/blob/69d3d7230378325a1de4fa313432d3d6ced4a518/test/files/genCerts.sh +certFolder=$1 +days=$2 + +pushd "$certFolder" + +# keys +openssl genrsa -out root.key +openssl genrsa -out client.key +openssl genrsa -out server.key + +# root cert +openssl req -x509 -new -nodes -key root.key -subj "/C=US/ST=KY/O=Org/CN=root" -sha256 -days "$days" -out root.crt + +# csrs +openssl req -new -sha256 -key client.key -subj "/C=US/ST=KY/O=Org/CN=client" -out client.csr +openssl req -new -sha256 -key server.key -subj "/C=US/ST=KY/O=Org/CN=localhost" -out server.csr + +openssl x509 -req -in client.csr -CA root.crt -CAkey root.key -CAcreateserial -out client.crt -days "$days" -sha256 +openssl x509 -req -in server.csr -CA root.crt -CAkey root.key -CAcreateserial -out server.crt -days "$days" -sha256 + +popd diff --git a/server/fake_server.pb.go b/server/fake_server.pb.go new file mode 100644 index 000000000..75ee6b0a1 --- /dev/null +++ b/server/fake_server.pb.go @@ -0,0 +1,653 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: fake_server.proto + +package server + +import ( + context "context" + fmt "fmt" + proto "github.com/gogo/protobuf/proto" + empty "github.com/golang/protobuf/ptypes/empty" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + io "io" + math "math" + math_bits "math/bits" + reflect "reflect" + strings "strings" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type FailWithHTTPErrorRequest struct { + Code int32 `protobuf:"varint,1,opt,name=Code,proto3" json:"Code,omitempty"` +} + +func (m *FailWithHTTPErrorRequest) Reset() { *m = FailWithHTTPErrorRequest{} } +func (*FailWithHTTPErrorRequest) ProtoMessage() {} +func (*FailWithHTTPErrorRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_a932e7b7b9f5c118, []int{0} +} +func (m *FailWithHTTPErrorRequest) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *FailWithHTTPErrorRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_FailWithHTTPErrorRequest.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *FailWithHTTPErrorRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_FailWithHTTPErrorRequest.Merge(m, src) +} +func (m *FailWithHTTPErrorRequest) XXX_Size() int { + return m.Size() +} +func (m *FailWithHTTPErrorRequest) XXX_DiscardUnknown() { + xxx_messageInfo_FailWithHTTPErrorRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_FailWithHTTPErrorRequest proto.InternalMessageInfo + +func (m *FailWithHTTPErrorRequest) GetCode() int32 { + if m != nil { + return m.Code + } + return 0 +} + +func init() { + proto.RegisterType((*FailWithHTTPErrorRequest)(nil), "server.FailWithHTTPErrorRequest") +} + +func init() { proto.RegisterFile("fake_server.proto", fileDescriptor_a932e7b7b9f5c118) } + +var fileDescriptor_a932e7b7b9f5c118 = []byte{ + // 265 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x4c, 0x4b, 0xcc, 0x4e, + 0x8d, 0x2f, 0x4e, 0x2d, 0x2a, 0x4b, 0x2d, 0xd2, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x83, + 0xf0, 0xa4, 0xa4, 0xd3, 0xf3, 0xf3, 0xd3, 0x73, 0x52, 0xf5, 0xc1, 0xa2, 0x49, 0xa5, 0x69, 0xfa, + 0xa9, 0xb9, 0x05, 0x25, 0x95, 0x10, 0x45, 0x4a, 0x7a, 0x5c, 0x12, 0x6e, 0x89, 0x99, 0x39, 0xe1, + 0x99, 0x25, 0x19, 0x1e, 0x21, 0x21, 0x01, 0xae, 0x45, 0x45, 0xf9, 0x45, 0x41, 0xa9, 0x85, 0xa5, + 0xa9, 0xc5, 0x25, 0x42, 0x42, 0x5c, 0x2c, 0xce, 0xf9, 0x29, 0xa9, 0x12, 0x8c, 0x0a, 0x8c, 0x1a, + 0xac, 0x41, 0x60, 0xb6, 0xd1, 0x6d, 0x26, 0x2e, 0x2e, 0xb7, 0xc4, 0xec, 0xd4, 0x60, 0xb0, 0xd9, + 0x42, 0xd6, 0x5c, 0xec, 0xc1, 0xa5, 0xc9, 0xc9, 0xa9, 0xa9, 0x29, 0x42, 0x62, 0x7a, 0x10, 0x7b, + 0xf4, 0x60, 0xf6, 0xe8, 0xb9, 0x82, 0xec, 0x91, 0xc2, 0x21, 0xae, 0xc4, 0x20, 0xe4, 0xc8, 0xc5, + 0x0b, 0xb3, 0x1b, 0x6c, 0x2f, 0x19, 0x46, 0xf8, 0x73, 0x09, 0x62, 0x38, 0x5f, 0x48, 0x41, 0x0f, + 0x1a, 0x0e, 0xb8, 0x7c, 0x86, 0xc7, 0x40, 0x4b, 0x2e, 0xd6, 0xe0, 0x9c, 0xd4, 0xd4, 0x02, 0xb2, + 0xbc, 0xc3, 0x1d, 0x5c, 0x52, 0x94, 0x9a, 0x98, 0x4b, 0xa6, 0x01, 0x06, 0x8c, 0x4e, 0x26, 0x17, + 0x1e, 0xca, 0x31, 0xdc, 0x78, 0x28, 0xc7, 0xf0, 0xe1, 0xa1, 0x1c, 0x63, 0xc3, 0x23, 0x39, 0xc6, + 0x15, 0x8f, 0xe4, 0x18, 0x4f, 0x3c, 0x92, 0x63, 0xbc, 0xf0, 0x48, 0x8e, 0xf1, 0xc1, 0x23, 0x39, + 0xc6, 0x17, 0x8f, 0xe4, 0x18, 0x3e, 0x3c, 0x92, 0x63, 0x9c, 0xf0, 0x58, 0x8e, 0xe1, 0xc2, 0x63, + 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0x92, 0xd8, 0xc0, 0x26, 0x19, 0x03, 0x02, 0x00, 0x00, 0xff, + 0xff, 0x43, 0x2b, 0x71, 0x6d, 0x04, 0x02, 0x00, 0x00, +} + +func (this *FailWithHTTPErrorRequest) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + that1, ok := that.(*FailWithHTTPErrorRequest) + if !ok { + that2, ok := that.(FailWithHTTPErrorRequest) + if ok { + that1 = &that2 + } else { + return false + } + } + if that1 == nil { + return this == nil + } else if this == nil { + return false + } + if this.Code != that1.Code { + return false + } + return true +} +func (this *FailWithHTTPErrorRequest) GoString() string { + if this == nil { + return "nil" + } + s := make([]string, 0, 5) + s = append(s, "&server.FailWithHTTPErrorRequest{") + s = append(s, "Code: "+fmt.Sprintf("%#v", this.Code)+",\n") + s = append(s, "}") + return strings.Join(s, "") +} +func valueToGoStringFakeServer(v interface{}, typ string) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("func(v %v) *%v { return &v } ( %#v )", typ, typ, pv) +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// FakeServerClient is the client API for FakeServer service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type FakeServerClient interface { + Succeed(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*empty.Empty, error) + FailWithError(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*empty.Empty, error) + FailWithHTTPError(ctx context.Context, in *FailWithHTTPErrorRequest, opts ...grpc.CallOption) (*empty.Empty, error) + Sleep(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*empty.Empty, error) + StreamSleep(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (FakeServer_StreamSleepClient, error) +} + +type fakeServerClient struct { + cc *grpc.ClientConn +} + +func NewFakeServerClient(cc *grpc.ClientConn) FakeServerClient { + return &fakeServerClient{cc} +} + +func (c *fakeServerClient) Succeed(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/server.FakeServer/Succeed", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *fakeServerClient) FailWithError(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/server.FakeServer/FailWithError", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *fakeServerClient) FailWithHTTPError(ctx context.Context, in *FailWithHTTPErrorRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/server.FakeServer/FailWithHTTPError", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *fakeServerClient) Sleep(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/server.FakeServer/Sleep", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *fakeServerClient) StreamSleep(ctx context.Context, in *empty.Empty, opts ...grpc.CallOption) (FakeServer_StreamSleepClient, error) { + stream, err := c.cc.NewStream(ctx, &_FakeServer_serviceDesc.Streams[0], "/server.FakeServer/StreamSleep", opts...) + if err != nil { + return nil, err + } + x := &fakeServerStreamSleepClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type FakeServer_StreamSleepClient interface { + Recv() (*empty.Empty, error) + grpc.ClientStream +} + +type fakeServerStreamSleepClient struct { + grpc.ClientStream +} + +func (x *fakeServerStreamSleepClient) Recv() (*empty.Empty, error) { + m := new(empty.Empty) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// FakeServerServer is the server API for FakeServer service. +type FakeServerServer interface { + Succeed(context.Context, *empty.Empty) (*empty.Empty, error) + FailWithError(context.Context, *empty.Empty) (*empty.Empty, error) + FailWithHTTPError(context.Context, *FailWithHTTPErrorRequest) (*empty.Empty, error) + Sleep(context.Context, *empty.Empty) (*empty.Empty, error) + StreamSleep(*empty.Empty, FakeServer_StreamSleepServer) error +} + +// UnimplementedFakeServerServer can be embedded to have forward compatible implementations. +type UnimplementedFakeServerServer struct { +} + +func (*UnimplementedFakeServerServer) Succeed(ctx context.Context, req *empty.Empty) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Succeed not implemented") +} +func (*UnimplementedFakeServerServer) FailWithError(ctx context.Context, req *empty.Empty) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method FailWithError not implemented") +} +func (*UnimplementedFakeServerServer) FailWithHTTPError(ctx context.Context, req *FailWithHTTPErrorRequest) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method FailWithHTTPError not implemented") +} +func (*UnimplementedFakeServerServer) Sleep(ctx context.Context, req *empty.Empty) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Sleep not implemented") +} +func (*UnimplementedFakeServerServer) StreamSleep(req *empty.Empty, srv FakeServer_StreamSleepServer) error { + return status.Errorf(codes.Unimplemented, "method StreamSleep not implemented") +} + +func RegisterFakeServerServer(s *grpc.Server, srv FakeServerServer) { + s.RegisterService(&_FakeServer_serviceDesc, srv) +} + +func _FakeServer_Succeed_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(empty.Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(FakeServerServer).Succeed(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/server.FakeServer/Succeed", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(FakeServerServer).Succeed(ctx, req.(*empty.Empty)) + } + return interceptor(ctx, in, info, handler) +} + +func _FakeServer_FailWithError_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(empty.Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(FakeServerServer).FailWithError(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/server.FakeServer/FailWithError", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(FakeServerServer).FailWithError(ctx, req.(*empty.Empty)) + } + return interceptor(ctx, in, info, handler) +} + +func _FakeServer_FailWithHTTPError_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(FailWithHTTPErrorRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(FakeServerServer).FailWithHTTPError(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/server.FakeServer/FailWithHTTPError", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(FakeServerServer).FailWithHTTPError(ctx, req.(*FailWithHTTPErrorRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _FakeServer_Sleep_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(empty.Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(FakeServerServer).Sleep(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/server.FakeServer/Sleep", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(FakeServerServer).Sleep(ctx, req.(*empty.Empty)) + } + return interceptor(ctx, in, info, handler) +} + +func _FakeServer_StreamSleep_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(empty.Empty) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(FakeServerServer).StreamSleep(m, &fakeServerStreamSleepServer{stream}) +} + +type FakeServer_StreamSleepServer interface { + Send(*empty.Empty) error + grpc.ServerStream +} + +type fakeServerStreamSleepServer struct { + grpc.ServerStream +} + +func (x *fakeServerStreamSleepServer) Send(m *empty.Empty) error { + return x.ServerStream.SendMsg(m) +} + +var _FakeServer_serviceDesc = grpc.ServiceDesc{ + ServiceName: "server.FakeServer", + HandlerType: (*FakeServerServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Succeed", + Handler: _FakeServer_Succeed_Handler, + }, + { + MethodName: "FailWithError", + Handler: _FakeServer_FailWithError_Handler, + }, + { + MethodName: "FailWithHTTPError", + Handler: _FakeServer_FailWithHTTPError_Handler, + }, + { + MethodName: "Sleep", + Handler: _FakeServer_Sleep_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "StreamSleep", + Handler: _FakeServer_StreamSleep_Handler, + ServerStreams: true, + }, + }, + Metadata: "fake_server.proto", +} + +func (m *FailWithHTTPErrorRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *FailWithHTTPErrorRequest) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *FailWithHTTPErrorRequest) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Code != 0 { + i = encodeVarintFakeServer(dAtA, i, uint64(m.Code)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func encodeVarintFakeServer(dAtA []byte, offset int, v uint64) int { + offset -= sovFakeServer(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *FailWithHTTPErrorRequest) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Code != 0 { + n += 1 + sovFakeServer(uint64(m.Code)) + } + return n +} + +func sovFakeServer(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozFakeServer(x uint64) (n int) { + return sovFakeServer(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (this *FailWithHTTPErrorRequest) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&FailWithHTTPErrorRequest{`, + `Code:` + fmt.Sprintf("%v", this.Code) + `,`, + `}`, + }, "") + return s +} +func valueToStringFakeServer(v interface{}) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("*%v", pv) +} +func (m *FailWithHTTPErrorRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowFakeServer + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: FailWithHTTPErrorRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: FailWithHTTPErrorRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Code", wireType) + } + m.Code = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowFakeServer + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Code |= int32(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipFakeServer(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthFakeServer + } + if (iNdEx + skippy) < 0 { + return ErrInvalidLengthFakeServer + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipFakeServer(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowFakeServer + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowFakeServer + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowFakeServer + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthFakeServer + } + iNdEx += length + if iNdEx < 0 { + return 0, ErrInvalidLengthFakeServer + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowFakeServer + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipFakeServer(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + if iNdEx < 0 { + return 0, ErrInvalidLengthFakeServer + } + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthFakeServer = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowFakeServer = fmt.Errorf("proto: integer overflow") +) diff --git a/server/fake_server.proto b/server/fake_server.proto new file mode 100644 index 000000000..248a6f244 --- /dev/null +++ b/server/fake_server.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package server; + +import "google/protobuf/empty.proto"; + +service FakeServer { + rpc Succeed(google.protobuf.Empty) returns (google.protobuf.Empty) {}; + rpc FailWithError(google.protobuf.Empty) returns (google.protobuf.Empty) {}; + rpc FailWithHTTPError(FailWithHTTPErrorRequest) returns (google.protobuf.Empty) {}; + rpc Sleep(google.protobuf.Empty) returns (google.protobuf.Empty) {}; + rpc StreamSleep(google.protobuf.Empty) returns (stream google.protobuf.Empty) {}; +} + +message FailWithHTTPErrorRequest { + int32 Code = 1; +} diff --git a/server/metrics.go b/server/metrics.go new file mode 100644 index 000000000..aa1c3e53a --- /dev/null +++ b/server/metrics.go @@ -0,0 +1,67 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/server/metrics.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package server + +import ( + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/grafana/dskit/instrument" + "github.com/grafana/dskit/middleware" +) + +type Metrics struct { + TCPConnections *prometheus.GaugeVec + TCPConnectionsLimit *prometheus.GaugeVec + RequestDuration *prometheus.HistogramVec + ReceivedMessageSize *prometheus.HistogramVec + SentMessageSize *prometheus.HistogramVec + InflightRequests *prometheus.GaugeVec +} + +func NewServerMetrics(cfg Config) *Metrics { + reg := promauto.With(cfg.registererOrDefault()) + + return &Metrics{ + TCPConnections: reg.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: cfg.MetricsNamespace, + Name: "tcp_connections", + Help: "Current number of accepted TCP connections.", + }, []string{"protocol"}), + TCPConnectionsLimit: reg.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: cfg.MetricsNamespace, + Name: "tcp_connections_limit", + Help: "The max number of TCP connections that can be accepted (0 means no limit).", + }, []string{"protocol"}), + RequestDuration: reg.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: cfg.MetricsNamespace, + Name: "request_duration_seconds", + Help: "Time (in seconds) spent serving HTTP requests.", + Buckets: instrument.DefBuckets, + NativeHistogramBucketFactor: cfg.MetricsNativeHistogramFactor, + NativeHistogramMaxBucketNumber: 100, + NativeHistogramMinResetDuration: time.Hour, + }, []string{"method", "route", "status_code", "ws"}), + ReceivedMessageSize: reg.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: cfg.MetricsNamespace, + Name: "request_message_bytes", + Help: "Size (in bytes) of messages received in the request.", + Buckets: middleware.BodySizeBuckets, + }, []string{"method", "route"}), + SentMessageSize: reg.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: cfg.MetricsNamespace, + Name: "response_message_bytes", + Help: "Size (in bytes) of messages sent in response.", + Buckets: middleware.BodySizeBuckets, + }, []string{"method", "route"}), + InflightRequests: reg.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: cfg.MetricsNamespace, + Name: "inflight_requests", + Help: "Current number of inflight requests.", + }, []string{"method", "route"}), + } +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 000000000..9275637cf --- /dev/null +++ b/server/server.go @@ -0,0 +1,563 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/server/server.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package server + +import ( + "context" + "crypto/tls" + "flag" + "fmt" + "math" + "net" + "net/http" + _ "net/http/pprof" // anonymous import to get the pprof handler registered + "strings" + "time" + + "github.com/gorilla/mux" + otgrpc "github.com/opentracing-contrib/go-grpc" + "github.com/opentracing/opentracing-go" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/prometheus/exporter-toolkit/web" + "github.com/soheilhy/cmux" + "golang.org/x/net/netutil" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/keepalive" + + "github.com/grafana/dskit/httpgrpc" + httpgrpc_server "github.com/grafana/dskit/httpgrpc/server" + "github.com/grafana/dskit/log" + "github.com/grafana/dskit/middleware" + "github.com/grafana/dskit/signals" +) + +// Listen on the named network +const ( + // DefaultNetwork the host resolves to multiple IP addresses, + // Dial will try each IP address in order until one succeeds + DefaultNetwork = "tcp" + // NetworkTCPV4 for IPV4 only + NetworkTCPV4 = "tcp4" +) + +// SignalHandler used by Server. +type SignalHandler interface { + // Starts the signals handler. This method is blocking, and returns only after signal is received, + // or "Stop" is called. + Loop() + + // Stop blocked "Loop" method. + Stop() +} + +// TLSConfig contains TLS parameters for Config. +type TLSConfig struct { + TLSCertPath string `yaml:"cert_file"` + TLSKeyPath string `yaml:"key_file"` + ClientAuth string `yaml:"client_auth_type"` + ClientCAs string `yaml:"client_ca_file"` +} + +// Config for a Server +type Config struct { + MetricsNamespace string `yaml:"-"` + // Set to > 1 to add native histograms to requestDuration. + // See documentation for NativeHistogramBucketFactor in + // https://pkg.go.dev/github.com/prometheus/client_golang/prometheus#HistogramOpts + // for details. A generally useful value is 1.1. + MetricsNativeHistogramFactor float64 `yaml:"-"` + + HTTPListenNetwork string `yaml:"http_listen_network"` + HTTPListenAddress string `yaml:"http_listen_address"` + HTTPListenPort int `yaml:"http_listen_port"` + HTTPConnLimit int `yaml:"http_listen_conn_limit"` + GRPCListenNetwork string `yaml:"grpc_listen_network"` + GRPCListenAddress string `yaml:"grpc_listen_address"` + GRPCListenPort int `yaml:"grpc_listen_port"` + GRPCConnLimit int `yaml:"grpc_listen_conn_limit"` + + CipherSuites string `yaml:"tls_cipher_suites"` + MinVersion string `yaml:"tls_min_version"` + HTTPTLSConfig TLSConfig `yaml:"http_tls_config"` + GRPCTLSConfig TLSConfig `yaml:"grpc_tls_config"` + + RegisterInstrumentation bool `yaml:"register_instrumentation"` + ExcludeRequestInLog bool `yaml:"-"` + DisableRequestSuccessLog bool `yaml:"-"` + + ServerGracefulShutdownTimeout time.Duration `yaml:"graceful_shutdown_timeout"` + HTTPServerReadTimeout time.Duration `yaml:"http_server_read_timeout"` + HTTPServerWriteTimeout time.Duration `yaml:"http_server_write_timeout"` + HTTPServerIdleTimeout time.Duration `yaml:"http_server_idle_timeout"` + + GRPCOptions []grpc.ServerOption `yaml:"-"` + GRPCMiddleware []grpc.UnaryServerInterceptor `yaml:"-"` + GRPCStreamMiddleware []grpc.StreamServerInterceptor `yaml:"-"` + HTTPMiddleware []middleware.Interface `yaml:"-"` + Router *mux.Router `yaml:"-"` + DoNotAddDefaultHTTPMiddleware bool `yaml:"-"` + RouteHTTPToGRPC bool `yaml:"-"` + + GPRCServerMaxRecvMsgSize int `yaml:"grpc_server_max_recv_msg_size"` + GRPCServerMaxSendMsgSize int `yaml:"grpc_server_max_send_msg_size"` + GPRCServerMaxConcurrentStreams uint `yaml:"grpc_server_max_concurrent_streams"` + GRPCServerMaxConnectionIdle time.Duration `yaml:"grpc_server_max_connection_idle"` + GRPCServerMaxConnectionAge time.Duration `yaml:"grpc_server_max_connection_age"` + GRPCServerMaxConnectionAgeGrace time.Duration `yaml:"grpc_server_max_connection_age_grace"` + GRPCServerTime time.Duration `yaml:"grpc_server_keepalive_time"` + GRPCServerTimeout time.Duration `yaml:"grpc_server_keepalive_timeout"` + GRPCServerMinTimeBetweenPings time.Duration `yaml:"grpc_server_min_time_between_pings"` + GRPCServerPingWithoutStreamAllowed bool `yaml:"grpc_server_ping_without_stream_allowed"` + + LogFormat log.Format `yaml:"log_format"` + LogLevel log.Level `yaml:"log_level"` + Log log.Interface `yaml:"-"` + LogSourceIPs bool `yaml:"log_source_ips_enabled"` + LogSourceIPsHeader string `yaml:"log_source_ips_header"` + LogSourceIPsRegex string `yaml:"log_source_ips_regex"` + LogRequestHeaders bool `yaml:"log_request_headers"` + LogRequestAtInfoLevel bool `yaml:"log_request_at_info_level_enabled"` + LogRequestExcludeHeadersList string `yaml:"log_request_exclude_headers_list"` + + // If not set, default signal handler is used. + SignalHandler SignalHandler `yaml:"-"` + + // If not set, default Prometheus registry is used. + Registerer prometheus.Registerer `yaml:"-"` + Gatherer prometheus.Gatherer `yaml:"-"` + + PathPrefix string `yaml:"http_path_prefix"` +} + +var infinty = time.Duration(math.MaxInt64) + +// RegisterFlags adds the flags required to config this to the given FlagSet +func (cfg *Config) RegisterFlags(f *flag.FlagSet) { + f.StringVar(&cfg.HTTPListenAddress, "server.http-listen-address", "", "HTTP server listen address.") + f.StringVar(&cfg.HTTPListenNetwork, "server.http-listen-network", DefaultNetwork, "HTTP server listen network, default tcp") + f.StringVar(&cfg.CipherSuites, "server.tls-cipher-suites", "", "Comma-separated list of cipher suites to use. If blank, the default Go cipher suites is used.") + f.StringVar(&cfg.MinVersion, "server.tls-min-version", "", "Minimum TLS version to use. Allowed values: VersionTLS10, VersionTLS11, VersionTLS12, VersionTLS13. If blank, the Go TLS minimum version is used.") + f.StringVar(&cfg.HTTPTLSConfig.TLSCertPath, "server.http-tls-cert-path", "", "HTTP server cert path.") + f.StringVar(&cfg.HTTPTLSConfig.TLSKeyPath, "server.http-tls-key-path", "", "HTTP server key path.") + f.StringVar(&cfg.HTTPTLSConfig.ClientAuth, "server.http-tls-client-auth", "", "HTTP TLS Client Auth type.") + f.StringVar(&cfg.HTTPTLSConfig.ClientCAs, "server.http-tls-ca-path", "", "HTTP TLS Client CA path.") + f.StringVar(&cfg.GRPCTLSConfig.TLSCertPath, "server.grpc-tls-cert-path", "", "GRPC TLS server cert path.") + f.StringVar(&cfg.GRPCTLSConfig.TLSKeyPath, "server.grpc-tls-key-path", "", "GRPC TLS server key path.") + f.StringVar(&cfg.GRPCTLSConfig.ClientAuth, "server.grpc-tls-client-auth", "", "GRPC TLS Client Auth type.") + f.StringVar(&cfg.GRPCTLSConfig.ClientCAs, "server.grpc-tls-ca-path", "", "GRPC TLS Client CA path.") + f.IntVar(&cfg.HTTPListenPort, "server.http-listen-port", 80, "HTTP server listen port.") + f.IntVar(&cfg.HTTPConnLimit, "server.http-conn-limit", 0, "Maximum number of simultaneous http connections, <=0 to disable") + f.StringVar(&cfg.GRPCListenNetwork, "server.grpc-listen-network", DefaultNetwork, "gRPC server listen network") + f.StringVar(&cfg.GRPCListenAddress, "server.grpc-listen-address", "", "gRPC server listen address.") + f.IntVar(&cfg.GRPCListenPort, "server.grpc-listen-port", 9095, "gRPC server listen port.") + f.IntVar(&cfg.GRPCConnLimit, "server.grpc-conn-limit", 0, "Maximum number of simultaneous grpc connections, <=0 to disable") + f.BoolVar(&cfg.RegisterInstrumentation, "server.register-instrumentation", true, "Register the intrumentation handlers (/metrics etc).") + f.DurationVar(&cfg.ServerGracefulShutdownTimeout, "server.graceful-shutdown-timeout", 30*time.Second, "Timeout for graceful shutdowns") + f.DurationVar(&cfg.HTTPServerReadTimeout, "server.http-read-timeout", 30*time.Second, "Read timeout for HTTP server") + f.DurationVar(&cfg.HTTPServerWriteTimeout, "server.http-write-timeout", 30*time.Second, "Write timeout for HTTP server") + f.DurationVar(&cfg.HTTPServerIdleTimeout, "server.http-idle-timeout", 120*time.Second, "Idle timeout for HTTP server") + f.IntVar(&cfg.GPRCServerMaxRecvMsgSize, "server.grpc-max-recv-msg-size-bytes", 4*1024*1024, "Limit on the size of a gRPC message this server can receive (bytes).") + f.IntVar(&cfg.GRPCServerMaxSendMsgSize, "server.grpc-max-send-msg-size-bytes", 4*1024*1024, "Limit on the size of a gRPC message this server can send (bytes).") + f.UintVar(&cfg.GPRCServerMaxConcurrentStreams, "server.grpc-max-concurrent-streams", 100, "Limit on the number of concurrent streams for gRPC calls (0 = unlimited)") + f.DurationVar(&cfg.GRPCServerMaxConnectionIdle, "server.grpc.keepalive.max-connection-idle", infinty, "The duration after which an idle connection should be closed. Default: infinity") + f.DurationVar(&cfg.GRPCServerMaxConnectionAge, "server.grpc.keepalive.max-connection-age", infinty, "The duration for the maximum amount of time a connection may exist before it will be closed. Default: infinity") + f.DurationVar(&cfg.GRPCServerMaxConnectionAgeGrace, "server.grpc.keepalive.max-connection-age-grace", infinty, "An additive period after max-connection-age after which the connection will be forcibly closed. Default: infinity") + f.DurationVar(&cfg.GRPCServerTime, "server.grpc.keepalive.time", time.Hour*2, "Duration after which a keepalive probe is sent in case of no activity over the connection., Default: 2h") + f.DurationVar(&cfg.GRPCServerTimeout, "server.grpc.keepalive.timeout", time.Second*20, "After having pinged for keepalive check, the duration after which an idle connection should be closed, Default: 20s") + f.DurationVar(&cfg.GRPCServerMinTimeBetweenPings, "server.grpc.keepalive.min-time-between-pings", 5*time.Minute, "Minimum amount of time a client should wait before sending a keepalive ping. If client sends keepalive ping more often, server will send GOAWAY and close the connection.") + f.BoolVar(&cfg.GRPCServerPingWithoutStreamAllowed, "server.grpc.keepalive.ping-without-stream-allowed", false, "If true, server allows keepalive pings even when there are no active streams(RPCs). If false, and client sends ping when there are no active streams, server will send GOAWAY and close the connection.") + f.StringVar(&cfg.PathPrefix, "server.path-prefix", "", "Base path to serve all API routes from (e.g. /v1/)") + cfg.LogFormat.RegisterFlags(f) + cfg.LogLevel.RegisterFlags(f) + f.BoolVar(&cfg.LogSourceIPs, "server.log-source-ips-enabled", false, "Optionally log the source IPs.") + f.StringVar(&cfg.LogSourceIPsHeader, "server.log-source-ips-header", "", "Header field storing the source IPs. Only used if server.log-source-ips-enabled is true. If not set the default Forwarded, X-Real-IP and X-Forwarded-For headers are used") + f.StringVar(&cfg.LogSourceIPsRegex, "server.log-source-ips-regex", "", "Regex for matching the source IPs. Only used if server.log-source-ips-enabled is true. If not set the default Forwarded, X-Real-IP and X-Forwarded-For headers are used") + f.BoolVar(&cfg.LogRequestHeaders, "server.log-request-headers", false, "Optionally log request headers.") + f.StringVar(&cfg.LogRequestExcludeHeadersList, "server.log-request-headers-exclude-list", "", "Comma separated list of headers to exclude from loggin. Only used if server.log-request-headers is true.") + f.BoolVar(&cfg.LogRequestAtInfoLevel, "server.log-request-at-info-level-enabled", false, "Optionally log requests at info level instead of debug level. Applies to request headers as well if server.log-request-headers is enabled.") +} + +func (cfg *Config) registererOrDefault() prometheus.Registerer { + // If user doesn't supply a Registerer/gatherer, use Prometheus' by default. + if cfg.Registerer != nil { + return cfg.Registerer + } + return prometheus.DefaultRegisterer +} + +// Server wraps a HTTP and gRPC server, and some common initialization. +// +// Servers will be automatically instrumented for Prometheus metrics. +type Server struct { + cfg Config + handler SignalHandler + grpcListener net.Listener + httpListener net.Listener + + // These fields are used to support grpc over the http server + // if RouteHTTPToGRPC is set. the fields are kept here + // so they can be initialized in New() and started in Run() + grpchttpmux cmux.CMux + grpcOnHTTPListener net.Listener + GRPCOnHTTPServer *grpc.Server + + HTTP *mux.Router + HTTPServer *http.Server + GRPC *grpc.Server + Log log.Interface + Registerer prometheus.Registerer + Gatherer prometheus.Gatherer +} + +// New makes a new Server. It will panic if the metrics cannot be registered. +func New(cfg Config) (*Server, error) { + metrics := NewServerMetrics(cfg) + return newServer(cfg, metrics) +} + +// NewWithMetrics makes a new Server using the provided Metrics. It will not attempt to register the metrics, +// the user is responsible for doing so. +func NewWithMetrics(cfg Config, metrics *Metrics) (*Server, error) { + return newServer(cfg, metrics) +} + +func newServer(cfg Config, metrics *Metrics) (*Server, error) { + // If user doesn't supply a logging implementation, by default instantiate + // logrus. + logger := cfg.Log + if logger == nil { + logger = log.NewLogrus(cfg.LogLevel) + } + + gatherer := cfg.Gatherer + if gatherer == nil { + gatherer = prometheus.DefaultGatherer + } + + network := cfg.HTTPListenNetwork + if network == "" { + network = DefaultNetwork + } + // Setup listeners first, so we can fail early if the port is in use. + httpListener, err := net.Listen(network, fmt.Sprintf("%s:%d", cfg.HTTPListenAddress, cfg.HTTPListenPort)) + if err != nil { + return nil, err + } + httpListener = middleware.CountingListener(httpListener, metrics.TCPConnections.WithLabelValues("http")) + + metrics.TCPConnectionsLimit.WithLabelValues("http").Set(float64(cfg.HTTPConnLimit)) + if cfg.HTTPConnLimit > 0 { + httpListener = netutil.LimitListener(httpListener, cfg.HTTPConnLimit) + } + + var grpcOnHTTPListener net.Listener + var grpchttpmux cmux.CMux + if cfg.RouteHTTPToGRPC { + grpchttpmux = cmux.New(httpListener) + + httpListener = grpchttpmux.Match(cmux.HTTP1Fast()) + grpcOnHTTPListener = grpchttpmux.Match(cmux.HTTP2()) + } + + network = cfg.GRPCListenNetwork + if network == "" { + network = DefaultNetwork + } + grpcListener, err := net.Listen(network, fmt.Sprintf("%s:%d", cfg.GRPCListenAddress, cfg.GRPCListenPort)) + if err != nil { + return nil, err + } + grpcListener = middleware.CountingListener(grpcListener, metrics.TCPConnections.WithLabelValues("grpc")) + + metrics.TCPConnectionsLimit.WithLabelValues("grpc").Set(float64(cfg.GRPCConnLimit)) + if cfg.GRPCConnLimit > 0 { + grpcListener = netutil.LimitListener(grpcListener, cfg.GRPCConnLimit) + } + + cipherSuites, err := stringToCipherSuites(cfg.CipherSuites) + if err != nil { + return nil, err + } + minVersion, err := stringToTLSVersion(cfg.MinVersion) + if err != nil { + return nil, err + } + + // Setup TLS + var httpTLSConfig *tls.Config + if len(cfg.HTTPTLSConfig.TLSCertPath) > 0 && len(cfg.HTTPTLSConfig.TLSKeyPath) > 0 { + // Note: ConfigToTLSConfig from prometheus/exporter-toolkit is awaiting security review. + httpTLSConfig, err = web.ConfigToTLSConfig(&web.TLSConfig{ + TLSCertPath: cfg.HTTPTLSConfig.TLSCertPath, + TLSKeyPath: cfg.HTTPTLSConfig.TLSKeyPath, + ClientAuth: cfg.HTTPTLSConfig.ClientAuth, + ClientCAs: cfg.HTTPTLSConfig.ClientCAs, + CipherSuites: cipherSuites, + MinVersion: minVersion, + }) + if err != nil { + return nil, fmt.Errorf("error generating http tls config: %v", err) + } + } + var grpcTLSConfig *tls.Config + if len(cfg.GRPCTLSConfig.TLSCertPath) > 0 && len(cfg.GRPCTLSConfig.TLSKeyPath) > 0 { + // Note: ConfigToTLSConfig from prometheus/exporter-toolkit is awaiting security review. + grpcTLSConfig, err = web.ConfigToTLSConfig(&web.TLSConfig{ + TLSCertPath: cfg.GRPCTLSConfig.TLSCertPath, + TLSKeyPath: cfg.GRPCTLSConfig.TLSKeyPath, + ClientAuth: cfg.GRPCTLSConfig.ClientAuth, + ClientCAs: cfg.GRPCTLSConfig.ClientCAs, + CipherSuites: cipherSuites, + MinVersion: minVersion, + }) + if err != nil { + return nil, fmt.Errorf("error generating grpc tls config: %v", err) + } + } + + logger.WithField("http", httpListener.Addr()).WithField("grpc", grpcListener.Addr()).Infof("server listening on addresses") + + // Setup gRPC server + serverLog := middleware.GRPCServerLog{ + Log: logger, + WithRequest: !cfg.ExcludeRequestInLog, + DisableRequestSuccessLog: cfg.DisableRequestSuccessLog, + } + grpcMiddleware := []grpc.UnaryServerInterceptor{ + serverLog.UnaryServerInterceptor, + otgrpc.OpenTracingServerInterceptor(opentracing.GlobalTracer()), + middleware.UnaryServerInstrumentInterceptor(metrics.RequestDuration), + } + grpcMiddleware = append(grpcMiddleware, cfg.GRPCMiddleware...) + + grpcStreamMiddleware := []grpc.StreamServerInterceptor{ + serverLog.StreamServerInterceptor, + otgrpc.OpenTracingStreamServerInterceptor(opentracing.GlobalTracer()), + middleware.StreamServerInstrumentInterceptor(metrics.RequestDuration), + } + grpcStreamMiddleware = append(grpcStreamMiddleware, cfg.GRPCStreamMiddleware...) + + grpcKeepAliveOptions := keepalive.ServerParameters{ + MaxConnectionIdle: cfg.GRPCServerMaxConnectionIdle, + MaxConnectionAge: cfg.GRPCServerMaxConnectionAge, + MaxConnectionAgeGrace: cfg.GRPCServerMaxConnectionAgeGrace, + Time: cfg.GRPCServerTime, + Timeout: cfg.GRPCServerTimeout, + } + + grpcKeepAliveEnforcementPolicy := keepalive.EnforcementPolicy{ + MinTime: cfg.GRPCServerMinTimeBetweenPings, + PermitWithoutStream: cfg.GRPCServerPingWithoutStreamAllowed, + } + + grpcOptions := []grpc.ServerOption{ + grpc.ChainUnaryInterceptor(grpcMiddleware...), + grpc.ChainStreamInterceptor(grpcStreamMiddleware...), + grpc.KeepaliveParams(grpcKeepAliveOptions), + grpc.KeepaliveEnforcementPolicy(grpcKeepAliveEnforcementPolicy), + grpc.MaxRecvMsgSize(cfg.GPRCServerMaxRecvMsgSize), + grpc.MaxSendMsgSize(cfg.GRPCServerMaxSendMsgSize), + grpc.MaxConcurrentStreams(uint32(cfg.GPRCServerMaxConcurrentStreams)), + grpc.StatsHandler(middleware.NewStatsHandler( + metrics.ReceivedMessageSize, + metrics.SentMessageSize, + metrics.InflightRequests, + )), + } + grpcOptions = append(grpcOptions, cfg.GRPCOptions...) + if grpcTLSConfig != nil { + grpcCreds := credentials.NewTLS(grpcTLSConfig) + grpcOptions = append(grpcOptions, grpc.Creds(grpcCreds)) + } + grpcServer := grpc.NewServer(grpcOptions...) + grpcOnHTTPServer := grpc.NewServer(grpcOptions...) + + // Setup HTTP server + var router *mux.Router + if cfg.Router != nil { + router = cfg.Router + } else { + router = mux.NewRouter() + } + if cfg.PathPrefix != "" { + // Expect metrics and pprof handlers to be prefixed with server's path prefix. + // e.g. /loki/metrics or /loki/debug/pprof + router = router.PathPrefix(cfg.PathPrefix).Subrouter() + } + if cfg.RegisterInstrumentation { + RegisterInstrumentationWithGatherer(router, gatherer) + } + + var sourceIPs *middleware.SourceIPExtractor + if cfg.LogSourceIPs { + sourceIPs, err = middleware.NewSourceIPs(cfg.LogSourceIPsHeader, cfg.LogSourceIPsRegex) + if err != nil { + return nil, fmt.Errorf("error setting up source IP extraction: %v", err) + } + } + + defaultLogMiddleware := middleware.NewLogMiddleware(logger, cfg.LogRequestHeaders, cfg.LogRequestAtInfoLevel, sourceIPs, strings.Split(cfg.LogRequestExcludeHeadersList, ",")) + defaultLogMiddleware.DisableRequestSuccessLog = cfg.DisableRequestSuccessLog + + defaultHTTPMiddleware := []middleware.Interface{ + middleware.Tracer{ + RouteMatcher: router, + SourceIPs: sourceIPs, + }, + defaultLogMiddleware, + middleware.Instrument{ + RouteMatcher: router, + Duration: metrics.RequestDuration, + RequestBodySize: metrics.ReceivedMessageSize, + ResponseBodySize: metrics.SentMessageSize, + InflightRequests: metrics.InflightRequests, + }, + } + var httpMiddleware []middleware.Interface + if cfg.DoNotAddDefaultHTTPMiddleware { + httpMiddleware = cfg.HTTPMiddleware + } else { + httpMiddleware = append(defaultHTTPMiddleware, cfg.HTTPMiddleware...) + } + + httpServer := &http.Server{ + ReadTimeout: cfg.HTTPServerReadTimeout, + WriteTimeout: cfg.HTTPServerWriteTimeout, + IdleTimeout: cfg.HTTPServerIdleTimeout, + Handler: middleware.Merge(httpMiddleware...).Wrap(router), + } + if httpTLSConfig != nil { + httpServer.TLSConfig = httpTLSConfig + } + + handler := cfg.SignalHandler + if handler == nil { + handler = signals.NewHandler(logger) + } + + return &Server{ + cfg: cfg, + httpListener: httpListener, + grpcListener: grpcListener, + grpcOnHTTPListener: grpcOnHTTPListener, + handler: handler, + grpchttpmux: grpchttpmux, + + HTTP: router, + HTTPServer: httpServer, + GRPC: grpcServer, + GRPCOnHTTPServer: grpcOnHTTPServer, + Log: logger, + Registerer: cfg.registererOrDefault(), + Gatherer: gatherer, + }, nil +} + +// RegisterInstrumentation on the given router. +func RegisterInstrumentation(router *mux.Router) { + RegisterInstrumentationWithGatherer(router, prometheus.DefaultGatherer) +} + +// RegisterInstrumentationWithGatherer on the given router. +func RegisterInstrumentationWithGatherer(router *mux.Router, gatherer prometheus.Gatherer) { + router.Handle("/metrics", promhttp.HandlerFor(gatherer, promhttp.HandlerOpts{ + EnableOpenMetrics: true, + })) + router.PathPrefix("/debug/pprof").Handler(http.DefaultServeMux) +} + +// Run the server; blocks until SIGTERM (if signal handling is enabled), an error is received, or Stop() is called. +func (s *Server) Run() error { + errChan := make(chan error, 1) + + // Wait for a signal + go func() { + s.handler.Loop() + select { + case errChan <- nil: + default: + } + }() + + go func() { + var err error + if s.HTTPServer.TLSConfig == nil { + err = s.HTTPServer.Serve(s.httpListener) + } else { + err = s.HTTPServer.ServeTLS(s.httpListener, s.cfg.HTTPTLSConfig.TLSCertPath, s.cfg.HTTPTLSConfig.TLSKeyPath) + } + if err == http.ErrServerClosed { + err = nil + } + + select { + case errChan <- err: + default: + } + }() + + // Setup gRPC server + // for HTTP over gRPC, ensure we don't double-count the middleware + httpgrpc.RegisterHTTPServer(s.GRPC, httpgrpc_server.NewServer(s.HTTP)) + + go func() { + err := s.GRPC.Serve(s.grpcListener) + handleGRPCError(err, errChan) + }() + + // grpchttpmux will only be set if grpchttpmux RouteHTTPToGRPC is set + if s.grpchttpmux != nil { + go func() { + err := s.grpchttpmux.Serve() + handleGRPCError(err, errChan) + }() + go func() { + err := s.GRPCOnHTTPServer.Serve(s.grpcOnHTTPListener) + handleGRPCError(err, errChan) + }() + } + + return <-errChan +} + +// handleGRPCError consolidates GRPC Server error handling by sending +// any error to errChan except for grpc.ErrServerStopped which is ignored. +func handleGRPCError(err error, errChan chan error) { + if err == grpc.ErrServerStopped { + err = nil + } + + select { + case errChan <- err: + default: + } +} + +// HTTPListenAddr exposes `net.Addr` that `Server` is listening to for HTTP connections. +func (s *Server) HTTPListenAddr() net.Addr { + return s.httpListener.Addr() + +} + +// GRPCListenAddr exposes `net.Addr` that `Server` is listening to for GRPC connections. +func (s *Server) GRPCListenAddr() net.Addr { + return s.grpcListener.Addr() +} + +// Stop unblocks Run(). +func (s *Server) Stop() { + s.handler.Stop() +} + +// Shutdown the server, gracefully. Should be defered after New(). +func (s *Server) Shutdown() { + ctx, cancel := context.WithTimeout(context.Background(), s.cfg.ServerGracefulShutdownTimeout) + defer cancel() // releases resources if httpServer.Shutdown completes before timeout elapses + + _ = s.HTTPServer.Shutdown(ctx) + s.GRPC.GracefulStop() +} diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 000000000..fbb5e968a --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,726 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/server/server_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package server + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "errors" + "flag" + "io" + "net/http" + "os" + "os/exec" + "strconv" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus/testutil" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + + protobuf "github.com/golang/protobuf/ptypes/empty" + "github.com/gorilla/mux" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" + + "github.com/grafana/dskit/httpgrpc" + "github.com/grafana/dskit/log" + "github.com/grafana/dskit/middleware" +) + +type FakeServer struct{} + +func (f FakeServer) FailWithError(_ context.Context, _ *protobuf.Empty) (*protobuf.Empty, error) { + return nil, errors.New("test error") +} + +func (f FakeServer) FailWithHTTPError(_ context.Context, req *FailWithHTTPErrorRequest) (*protobuf.Empty, error) { + return nil, httpgrpc.Errorf(int(req.Code), strconv.Itoa(int(req.Code))) +} + +func (f FakeServer) Succeed(_ context.Context, _ *protobuf.Empty) (*protobuf.Empty, error) { + return &protobuf.Empty{}, nil +} + +func (f FakeServer) Sleep(ctx context.Context, _ *protobuf.Empty) (*protobuf.Empty, error) { + err := cancelableSleep(ctx, 10*time.Second) + return &protobuf.Empty{}, err +} + +func (f FakeServer) StreamSleep(_ *protobuf.Empty, stream FakeServer_StreamSleepServer) error { + for x := 0; x < 100; x++ { + time.Sleep(time.Second / 100.0) + if err := stream.Send(&protobuf.Empty{}); err != nil { + return err + } + } + return nil +} + +func cancelableSleep(ctx context.Context, sleep time.Duration) error { + select { + case <-time.After(sleep): + case <-ctx.Done(): + } + return ctx.Err() +} + +func TestTCPv4Network(t *testing.T) { + cfg := Config{ + HTTPListenNetwork: NetworkTCPV4, + HTTPListenAddress: "localhost", + HTTPListenPort: 9290, + GRPCListenNetwork: NetworkTCPV4, + GRPCListenAddress: "localhost", + GRPCListenPort: 9291, + } + t.Run("http", func(t *testing.T) { + cfg.MetricsNamespace = "testing_http_tcp4" + srv, err := New(cfg) + require.NoError(t, err) + + errChan := make(chan error, 1) + go func() { + errChan <- srv.Run() + }() + + require.NoError(t, srv.httpListener.Close()) + require.NotNil(t, <-errChan) + + // So that address is freed for further tests. + srv.GRPC.Stop() + }) + + t.Run("grpc", func(t *testing.T) { + cfg.MetricsNamespace = "testing_grpc_tcp4" + srv, err := New(cfg) + require.NoError(t, err) + + errChan := make(chan error, 1) + go func() { + errChan <- srv.Run() + }() + + require.NoError(t, srv.grpcListener.Close()) + require.NotNil(t, <-errChan) + }) +} + +// Ensure that http and grpc servers work with no overrides to config +// (except http port because an ordinary user can't bind to default port 80) +func TestDefaultAddresses(t *testing.T) { + var cfg Config + cfg.RegisterFlags(flag.NewFlagSet("", flag.ExitOnError)) + cfg.HTTPListenPort = 9090 + cfg.MetricsNamespace = "testing_addresses" + + server, err := New(cfg) + require.NoError(t, err) + + fakeServer := FakeServer{} + RegisterFakeServerServer(server.GRPC, fakeServer) + + server.HTTP.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(204) + }) + + go func() { + require.NoError(t, server.Run()) + }() + defer server.Shutdown() + + conn, err := grpc.Dial("localhost:9095", grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + empty := protobuf.Empty{} + client := NewFakeServerClient(conn) + _, err = client.Succeed(context.Background(), &empty) + require.NoError(t, err) + + req, err := http.NewRequest("GET", "http://127.0.0.1:9090/test", nil) + require.NoError(t, err) + _, err = http.DefaultClient.Do(req) + require.NoError(t, err) +} + +func TestErrorInstrumentationMiddleware(t *testing.T) { + var cfg Config + cfg.RegisterFlags(flag.NewFlagSet("", flag.ExitOnError)) + cfg.HTTPListenPort = 9090 // can't use 80 as ordinary user + cfg.GRPCListenAddress = "localhost" + cfg.GRPCListenPort = 1234 + server, err := New(cfg) + require.NoError(t, err) + + fakeServer := FakeServer{} + RegisterFakeServerServer(server.GRPC, fakeServer) + + server.HTTP.HandleFunc("/succeed", func(w http.ResponseWriter, r *http.Request) { + }) + server.HTTP.HandleFunc("/error500", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + }) + server.HTTP.HandleFunc("/sleep10", func(w http.ResponseWriter, r *http.Request) { + _ = cancelableSleep(r.Context(), time.Second*10) + }) + server.HTTP.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }) + + go func() { + require.NoError(t, server.Run()) + }() + + conn, err := grpc.Dial("localhost:1234", grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + empty := protobuf.Empty{} + client := NewFakeServerClient(conn) + res, err := client.Succeed(context.Background(), &empty) + require.NoError(t, err) + require.EqualValues(t, &empty, res) + + res, err = client.FailWithError(context.Background(), &empty) + require.Nil(t, res) + require.Error(t, err) + + s, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, "test error", s.Message()) + + res, err = client.FailWithHTTPError(context.Background(), &FailWithHTTPErrorRequest{Code: http.StatusPaymentRequired}) + require.Nil(t, res) + errResp, ok := httpgrpc.HTTPResponseFromError(err) + require.True(t, ok) + require.Equal(t, int32(http.StatusPaymentRequired), errResp.Code) + require.Equal(t, "402", string(errResp.Body)) + + callThenCancel := func(f func(ctx context.Context) error) error { + ctx, cancel := context.WithCancel(context.Background()) + errChan := make(chan error, 1) + go func() { + errChan <- f(ctx) + }() + time.Sleep(50 * time.Millisecond) // allow the call to reach the handler + cancel() + return <-errChan + } + + err = callThenCancel(func(ctx context.Context) error { + _, err = client.Sleep(ctx, &empty) + return err + }) + require.Error(t, err, context.Canceled) + + err = callThenCancel(func(ctx context.Context) error { + _, err = client.StreamSleep(ctx, &empty) + return err + }) + require.NoError(t, err) // canceling a streaming fn doesn't generate an error + + // Now test the HTTP versions of the functions + { + req, err := http.NewRequest("GET", "http://127.0.0.1:9090/succeed", nil) + require.NoError(t, err) + _, err = http.DefaultClient.Do(req) + require.NoError(t, err) + } + { + req, err := http.NewRequest("GET", "http://127.0.0.1:9090/error500", nil) + require.NoError(t, err) + _, err = http.DefaultClient.Do(req) + require.NoError(t, err) + } + { + req, err := http.NewRequest("GET", "http://127.0.0.1:9090/notfound", nil) + require.NoError(t, err) + _, err = http.DefaultClient.Do(req) + require.NoError(t, err) + } + { + req, err := http.NewRequest("GET", "http://127.0.0.1:9090/sleep10", nil) + require.NoError(t, err) + err = callThenCancel(func(ctx context.Context) error { + _, err = http.DefaultClient.Do(req.WithContext(ctx)) + return err + }) + require.Error(t, err, context.Canceled) + } + + require.NoError(t, conn.Close()) + server.Shutdown() + + metrics, err := prometheus.DefaultGatherer.Gather() + require.NoError(t, err) + + statuses := map[string]string{} + for _, family := range metrics { + if *family.Name == "request_duration_seconds" { + for _, metric := range family.Metric { + var route, statusCode string + for _, label := range metric.GetLabel() { + switch label.GetName() { + case "status_code": + statusCode = label.GetValue() + case "route": + route = label.GetValue() + } + } + statuses[route] = statusCode + } + } + } + require.Equal(t, map[string]string{ + "/server.FakeServer/FailWithError": "error", + "/server.FakeServer/FailWithHTTPError": "402", + "/server.FakeServer/Sleep": "cancel", + "/server.FakeServer/StreamSleep": "cancel", + "/server.FakeServer/Succeed": "success", + "error500": "500", + "sleep10": "200", + "succeed": "200", + "notfound": "404", + }, statuses) +} + +func TestHTTPInstrumentationMetrics(t *testing.T) { + reg := prometheus.NewRegistry() + prometheus.DefaultRegisterer = reg + prometheus.DefaultGatherer = reg + + var cfg Config + cfg.RegisterFlags(flag.NewFlagSet("", flag.ExitOnError)) + cfg.HTTPListenPort = 9090 // can't use 80 as ordinary user + cfg.GRPCListenAddress = "localhost" + cfg.GRPCListenPort = 1234 + server, err := New(cfg) + require.NoError(t, err) + + server.HTTP.HandleFunc("/succeed", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("OK")) + }) + server.HTTP.HandleFunc("/error500", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + }) + server.HTTP.HandleFunc("/sleep10", func(w http.ResponseWriter, r *http.Request) { + _, _ = io.Copy(io.Discard, r.Body) // Consume body, otherwise it's not counted. + _ = cancelableSleep(r.Context(), time.Second*10) + }) + + go func() { + require.NoError(t, server.Run()) + }() + + callThenCancel := func(f func(ctx context.Context) error) error { + ctx, cancel := context.WithCancel(context.Background()) + errChan := make(chan error, 1) + go func() { + errChan <- f(ctx) + }() + time.Sleep(50 * time.Millisecond) // allow the call to reach the handler + cancel() + return <-errChan + } + + // Now test the HTTP versions of the functions + { + req, err := http.NewRequest("GET", "http://127.0.0.1:9090/succeed", nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "OK", string(body)) + } + { + req, err := http.NewRequest("GET", "http://127.0.0.1:9090/error500", nil) + require.NoError(t, err) + _, err = http.DefaultClient.Do(req) + require.NoError(t, err) + } + { + req, err := http.NewRequest("POST", "http://127.0.0.1:9090/sleep10", bytes.NewReader([]byte("Body"))) + require.NoError(t, err) + err = callThenCancel(func(ctx context.Context) error { + _, err = http.DefaultClient.Do(req.WithContext(ctx)) + return err + }) + require.Error(t, err, context.Canceled) + } + + server.Shutdown() + + require.NoError(t, testutil.GatherAndCompare(prometheus.DefaultGatherer, bytes.NewBufferString(` + # HELP inflight_requests Current number of inflight requests. + # TYPE inflight_requests gauge + inflight_requests{method="POST",route="sleep10"} 0 + inflight_requests{method="GET",route="succeed"} 0 + inflight_requests{method="GET",route="error500"} 0 + + # HELP request_message_bytes Size (in bytes) of messages received in the request. + # TYPE request_message_bytes histogram + request_message_bytes_bucket{method="GET",route="error500",le="1.048576e+06"} 1 + request_message_bytes_bucket{method="GET",route="error500",le="2.62144e+06"} 1 + request_message_bytes_bucket{method="GET",route="error500",le="5.24288e+06"} 1 + request_message_bytes_bucket{method="GET",route="error500",le="1.048576e+07"} 1 + request_message_bytes_bucket{method="GET",route="error500",le="2.62144e+07"} 1 + request_message_bytes_bucket{method="GET",route="error500",le="5.24288e+07"} 1 + request_message_bytes_bucket{method="GET",route="error500",le="1.048576e+08"} 1 + request_message_bytes_bucket{method="GET",route="error500",le="2.62144e+08"} 1 + request_message_bytes_bucket{method="GET",route="error500",le="+Inf"} 1 + request_message_bytes_sum{method="GET",route="error500"} 0 + request_message_bytes_count{method="GET",route="error500"} 1 + + request_message_bytes_bucket{method="POST",route="sleep10",le="1.048576e+06"} 1 + request_message_bytes_bucket{method="POST",route="sleep10",le="2.62144e+06"} 1 + request_message_bytes_bucket{method="POST",route="sleep10",le="5.24288e+06"} 1 + request_message_bytes_bucket{method="POST",route="sleep10",le="1.048576e+07"} 1 + request_message_bytes_bucket{method="POST",route="sleep10",le="2.62144e+07"} 1 + request_message_bytes_bucket{method="POST",route="sleep10",le="5.24288e+07"} 1 + request_message_bytes_bucket{method="POST",route="sleep10",le="1.048576e+08"} 1 + request_message_bytes_bucket{method="POST",route="sleep10",le="2.62144e+08"} 1 + request_message_bytes_bucket{method="POST",route="sleep10",le="+Inf"} 1 + request_message_bytes_sum{method="POST",route="sleep10"} 4 + request_message_bytes_count{method="POST",route="sleep10"} 1 + + request_message_bytes_bucket{method="GET",route="succeed",le="1.048576e+06"} 1 + request_message_bytes_bucket{method="GET",route="succeed",le="2.62144e+06"} 1 + request_message_bytes_bucket{method="GET",route="succeed",le="5.24288e+06"} 1 + request_message_bytes_bucket{method="GET",route="succeed",le="1.048576e+07"} 1 + request_message_bytes_bucket{method="GET",route="succeed",le="2.62144e+07"} 1 + request_message_bytes_bucket{method="GET",route="succeed",le="5.24288e+07"} 1 + request_message_bytes_bucket{method="GET",route="succeed",le="1.048576e+08"} 1 + request_message_bytes_bucket{method="GET",route="succeed",le="2.62144e+08"} 1 + request_message_bytes_bucket{method="GET",route="succeed",le="+Inf"} 1 + request_message_bytes_sum{method="GET",route="succeed"} 0 + request_message_bytes_count{method="GET",route="succeed"} 1 + + # HELP response_message_bytes Size (in bytes) of messages sent in response. + # TYPE response_message_bytes histogram + response_message_bytes_bucket{method="GET",route="error500",le="1.048576e+06"} 1 + response_message_bytes_bucket{method="GET",route="error500",le="2.62144e+06"} 1 + response_message_bytes_bucket{method="GET",route="error500",le="5.24288e+06"} 1 + response_message_bytes_bucket{method="GET",route="error500",le="1.048576e+07"} 1 + response_message_bytes_bucket{method="GET",route="error500",le="2.62144e+07"} 1 + response_message_bytes_bucket{method="GET",route="error500",le="5.24288e+07"} 1 + response_message_bytes_bucket{method="GET",route="error500",le="1.048576e+08"} 1 + response_message_bytes_bucket{method="GET",route="error500",le="2.62144e+08"} 1 + response_message_bytes_bucket{method="GET",route="error500",le="+Inf"} 1 + response_message_bytes_sum{method="GET",route="error500"} 0 + response_message_bytes_count{method="GET",route="error500"} 1 + + response_message_bytes_bucket{method="POST",route="sleep10",le="1.048576e+06"} 1 + response_message_bytes_bucket{method="POST",route="sleep10",le="2.62144e+06"} 1 + response_message_bytes_bucket{method="POST",route="sleep10",le="5.24288e+06"} 1 + response_message_bytes_bucket{method="POST",route="sleep10",le="1.048576e+07"} 1 + response_message_bytes_bucket{method="POST",route="sleep10",le="2.62144e+07"} 1 + response_message_bytes_bucket{method="POST",route="sleep10",le="5.24288e+07"} 1 + response_message_bytes_bucket{method="POST",route="sleep10",le="1.048576e+08"} 1 + response_message_bytes_bucket{method="POST",route="sleep10",le="2.62144e+08"} 1 + response_message_bytes_bucket{method="POST",route="sleep10",le="+Inf"} 1 + response_message_bytes_sum{method="POST",route="sleep10"} 0 + response_message_bytes_count{method="POST",route="sleep10"} 1 + + response_message_bytes_bucket{method="GET",route="succeed",le="1.048576e+06"} 1 + response_message_bytes_bucket{method="GET",route="succeed",le="2.62144e+06"} 1 + response_message_bytes_bucket{method="GET",route="succeed",le="5.24288e+06"} 1 + response_message_bytes_bucket{method="GET",route="succeed",le="1.048576e+07"} 1 + response_message_bytes_bucket{method="GET",route="succeed",le="2.62144e+07"} 1 + response_message_bytes_bucket{method="GET",route="succeed",le="5.24288e+07"} 1 + response_message_bytes_bucket{method="GET",route="succeed",le="1.048576e+08"} 1 + response_message_bytes_bucket{method="GET",route="succeed",le="2.62144e+08"} 1 + response_message_bytes_bucket{method="GET",route="succeed",le="+Inf"} 1 + response_message_bytes_sum{method="GET",route="succeed"} 2 + response_message_bytes_count{method="GET",route="succeed"} 1 + + # HELP tcp_connections Current number of accepted TCP connections. + # TYPE tcp_connections gauge + tcp_connections{protocol="http"} 0 + tcp_connections{protocol="grpc"} 0 + `), "request_message_bytes", "response_message_bytes", "inflight_requests", "tcp_connections")) +} + +func TestRunReturnsError(t *testing.T) { + cfg := Config{ + HTTPListenNetwork: DefaultNetwork, + HTTPListenAddress: "localhost", + HTTPListenPort: 9090, + GRPCListenNetwork: DefaultNetwork, + GRPCListenAddress: "localhost", + GRPCListenPort: 9191, + } + t.Run("http", func(t *testing.T) { + cfg.MetricsNamespace = "testing_http" + srv, err := New(cfg) + require.NoError(t, err) + + errChan := make(chan error, 1) + go func() { + errChan <- srv.Run() + }() + + require.NoError(t, srv.httpListener.Close()) + require.NotNil(t, <-errChan) + + // So that address is freed for further tests. + srv.GRPC.Stop() + }) + + t.Run("grpc", func(t *testing.T) { + cfg.MetricsNamespace = "testing_grpc" + srv, err := New(cfg) + require.NoError(t, err) + + errChan := make(chan error, 1) + go func() { + errChan <- srv.Run() + }() + + require.NoError(t, srv.grpcListener.Close()) + require.NotNil(t, <-errChan) + }) +} + +// Test to see what the logging of a 500 error looks like +func TestMiddlewareLogging(t *testing.T) { + var level log.Level + require.NoError(t, level.Set("info")) + cfg := Config{ + HTTPListenNetwork: DefaultNetwork, + HTTPListenAddress: "localhost", + HTTPListenPort: 9192, + GRPCListenNetwork: DefaultNetwork, + GRPCListenAddress: "localhost", + HTTPMiddleware: []middleware.Interface{middleware.Logging}, + MetricsNamespace: "testing_logging", + LogLevel: level, + DoNotAddDefaultHTTPMiddleware: true, + Router: &mux.Router{}, + } + server, err := New(cfg) + require.NoError(t, err) + + server.HTTP.HandleFunc("/error500", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + }) + + go func() { + require.NoError(t, server.Run()) + }() + defer server.Shutdown() + + req, err := http.NewRequest("GET", "http://127.0.0.1:9192/error500", nil) + require.NoError(t, err) + _, err = http.DefaultClient.Do(req) + require.NoError(t, err) +} + +func TestTLSServer(t *testing.T) { + var level log.Level + require.NoError(t, level.Set("info")) + + cmd := exec.Command("bash", "certs/genCerts.sh", "certs", "1") + err := cmd.Run() + require.NoError(t, err) + + cfg := Config{ + HTTPListenNetwork: DefaultNetwork, + HTTPListenAddress: "localhost", + HTTPListenPort: 9193, + HTTPTLSConfig: TLSConfig{ + TLSCertPath: "certs/server.crt", + TLSKeyPath: "certs/server.key", + ClientAuth: "RequireAndVerifyClientCert", + ClientCAs: "certs/root.crt", + }, + GRPCTLSConfig: TLSConfig{ + TLSCertPath: "certs/server.crt", + TLSKeyPath: "certs/server.key", + ClientAuth: "VerifyClientCertIfGiven", + ClientCAs: "certs/root.crt", + }, + MetricsNamespace: "testing_tls", + GRPCListenNetwork: DefaultNetwork, + GRPCListenAddress: "localhost", + GRPCListenPort: 9194, + } + server, err := New(cfg) + require.NoError(t, err) + + server.HTTP.HandleFunc("/testhttps", func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("Hello World!")) + require.NoError(t, err) + }) + + fakeServer := FakeServer{} + RegisterFakeServerServer(server.GRPC, fakeServer) + + go func() { + require.NoError(t, server.Run()) + }() + defer server.Shutdown() + + clientCert, err := tls.LoadX509KeyPair("certs/client.crt", "certs/client.key") + require.NoError(t, err) + + caCert, err := os.ReadFile(cfg.HTTPTLSConfig.ClientCAs) + require.NoError(t, err) + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + Certificates: []tls.Certificate{clientCert}, + RootCAs: caCertPool, + } + tr := &http.Transport{ + TLSClientConfig: tlsConfig, + } + + client := &http.Client{Transport: tr} + res, err := client.Get("https://localhost:9193/testhttps") + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, res.StatusCode, http.StatusOK) + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + expected := []byte("Hello World!") + require.Equal(t, expected, body) + + conn, err := grpc.Dial("localhost:9194", grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + require.NoError(t, err) + defer conn.Close() + + empty := protobuf.Empty{} + grpcClient := NewFakeServerClient(conn) + grpcRes, err := grpcClient.Succeed(context.Background(), &empty) + require.NoError(t, err) + require.EqualValues(t, &empty, grpcRes) +} + +type FakeLogger struct { + sourceIPs string +} + +func (f *FakeLogger) Debugf(_ string, _ ...interface{}) {} +func (f *FakeLogger) Debugln(_ ...interface{}) {} + +func (f *FakeLogger) Infof(_ string, _ ...interface{}) {} +func (f *FakeLogger) Infoln(_ ...interface{}) {} + +func (f *FakeLogger) Errorf(_ string, _ ...interface{}) {} +func (f *FakeLogger) Errorln(_ ...interface{}) {} + +func (f *FakeLogger) Warnf(_ string, _ ...interface{}) {} +func (f *FakeLogger) Warnln(_ ...interface{}) {} + +func (f *FakeLogger) WithField(key string, value interface{}) log.Interface { + if key == "sourceIPs" { + f.sourceIPs = value.(string) + } + + return f +} + +func (f *FakeLogger) WithFields(_ log.Fields) log.Interface { + return f +} + +func TestLogSourceIPs(t *testing.T) { + var level log.Level + require.NoError(t, level.Set("debug")) + fake := FakeLogger{} + cfg := Config{ + HTTPListenNetwork: DefaultNetwork, + HTTPListenAddress: "localhost", + HTTPListenPort: 9195, + GRPCListenNetwork: DefaultNetwork, + GRPCListenAddress: "localhost", + HTTPMiddleware: []middleware.Interface{middleware.Logging}, + MetricsNamespace: "testing_mux", + LogLevel: level, + Log: &fake, + LogSourceIPs: true, + } + server, err := New(cfg) + require.NoError(t, err) + + server.HTTP.HandleFunc("/error500", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + }) + + go func() { + require.NoError(t, server.Run()) + }() + defer server.Shutdown() + + require.Empty(t, fake.sourceIPs) + + req, err := http.NewRequest("GET", "http://127.0.0.1:9195/error500", nil) + require.NoError(t, err) + _, err = http.DefaultClient.Do(req) + require.NoError(t, err) + + require.Equal(t, fake.sourceIPs, "127.0.0.1") +} + +func TestStopWithDisabledSignalHandling(t *testing.T) { + cfg := Config{ + HTTPListenNetwork: DefaultNetwork, + HTTPListenAddress: "localhost", + HTTPListenPort: 9198, + GRPCListenNetwork: DefaultNetwork, + GRPCListenAddress: "localhost", + GRPCListenPort: 9199, + } + + var test = func(t *testing.T, metricsNamespace string, handler SignalHandler) { + cfg.SignalHandler = handler + cfg.MetricsNamespace = metricsNamespace + srv, err := New(cfg) + require.NoError(t, err) + + errChan := make(chan error, 1) + go func() { + errChan <- srv.Run() + }() + + srv.Stop() + require.Nil(t, <-errChan) + + // So that addresses is freed for further tests. + srv.Shutdown() + } + + t.Run("signals_enabled", func(t *testing.T) { + test(t, "signals_enabled", nil) + }) + + t.Run("signals_disabled", func(t *testing.T) { + test(t, "signals_disabled", dummyHandler{quit: make(chan struct{})}) + }) +} + +type dummyHandler struct { + quit chan struct{} +} + +func (dh dummyHandler) Loop() { + <-dh.quit +} + +func (dh dummyHandler) Stop() { + close(dh.quit) +} diff --git a/server/tls_config.go b/server/tls_config.go new file mode 100644 index 000000000..128c6d04f --- /dev/null +++ b/server/tls_config.go @@ -0,0 +1,59 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/server/tls_config.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package server + +import ( + "crypto/tls" + fmt "fmt" + "strings" + + "github.com/prometheus/exporter-toolkit/web" +) + +// Collect all cipher suite names and IDs recognized by Go, including insecure ones. +func allCiphers() map[string]web.Cipher { + acceptedCiphers := make(map[string]web.Cipher) + for _, suite := range tls.CipherSuites() { + acceptedCiphers[suite.Name] = web.Cipher(suite.ID) + } + for _, suite := range tls.InsecureCipherSuites() { + acceptedCiphers[suite.Name] = web.Cipher(suite.ID) + } + return acceptedCiphers +} + +func stringToCipherSuites(s string) ([]web.Cipher, error) { + if s == "" { + return nil, nil + } + ciphersSlice := []web.Cipher{} + possibleCiphers := allCiphers() + for _, cipher := range strings.Split(s, ",") { + intValue, ok := possibleCiphers[cipher] + if !ok { + return nil, fmt.Errorf("cipher suite %q not recognized", cipher) + } + ciphersSlice = append(ciphersSlice, intValue) + } + return ciphersSlice, nil +} + +// Using the same names that Kubernetes does +var tlsVersions = map[string]uint16{ + "VersionTLS10": tls.VersionTLS10, + "VersionTLS11": tls.VersionTLS11, + "VersionTLS12": tls.VersionTLS12, + "VersionTLS13": tls.VersionTLS13, +} + +func stringToTLSVersion(s string) (web.TLSVersion, error) { + if s == "" { + return 0, nil + } + if version, ok := tlsVersions[s]; ok { + return web.TLSVersion(version), nil + } + return 0, fmt.Errorf("TLS version %q not recognized", s) +} diff --git a/server/tls_config_test.go b/server/tls_config_test.go new file mode 100644 index 000000000..0b7e88162 --- /dev/null +++ b/server/tls_config_test.go @@ -0,0 +1,63 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/server/tls_config_test.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package server + +import ( + "crypto/tls" + "testing" + + "github.com/prometheus/exporter-toolkit/web" + "github.com/stretchr/testify/require" +) + +func Test_stringToCipherSuites(t *testing.T) { + tests := []struct { + name string + arg string + want []web.Cipher + wantErr bool + }{ + {name: "blank", arg: "", want: nil}, + {name: "bad", arg: "not-a-cipher", wantErr: true}, + {name: "one", arg: "TLS_AES_256_GCM_SHA384", want: []web.Cipher{web.Cipher(tls.TLS_AES_256_GCM_SHA384)}}, + {name: "two", arg: "TLS_AES_256_GCM_SHA384,TLS_CHACHA20_POLY1305_SHA256", + want: []web.Cipher{web.Cipher(tls.TLS_AES_256_GCM_SHA384), web.Cipher(tls.TLS_CHACHA20_POLY1305_SHA256)}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := stringToCipherSuites(tt.arg) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.want, got) + }) + } +} + +func Test_stringToTLSVersion(t *testing.T) { + tests := []struct { + name string + arg string + want web.TLSVersion + wantErr bool + }{ + {name: "blank", arg: "", want: 0}, + {name: "bad", arg: "not-a-version", wantErr: true}, + {name: "VersionTLS12", arg: "VersionTLS12", want: tls.VersionTLS12}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := stringToTLSVersion(tt.arg) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.want, got) + }) + } +} diff --git a/signals/signals.go b/signals/signals.go new file mode 100644 index 000000000..75609a745 --- /dev/null +++ b/signals/signals.go @@ -0,0 +1,77 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/signals/signals.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package signals + +import ( + "os" + "os/signal" + "runtime" + "syscall" + + "github.com/grafana/dskit/log" +) + +// SignalReceiver represents a subsystem/server/... that can be stopped or +// queried about the status with a signal +type SignalReceiver interface { + Stop() error +} + +// Handler handles signals, can be interrupted. +// On SIGINT or SIGTERM it will exit, on SIGQUIT it +// will dump goroutine stacks to the Logger. +type Handler struct { + log log.Interface + receivers []SignalReceiver + quit chan struct{} +} + +// NewHandler makes a new Handler. +func NewHandler(log log.Interface, receivers ...SignalReceiver) *Handler { + return &Handler{ + log: log, + receivers: receivers, + quit: make(chan struct{}), + } +} + +// Stop the handler +func (h *Handler) Stop() { + close(h.quit) +} + +// Loop handles signals. +func (h *Handler) Loop() { + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM) + defer signal.Stop(sigs) + buf := make([]byte, 1<<20) + for { + select { + case <-h.quit: + h.log.Infof("=== Handler.Stop()'d ===") + return + case sig := <-sigs: + switch sig { + case syscall.SIGINT, syscall.SIGTERM: + h.log.Infof("=== received SIGINT/SIGTERM ===\n*** exiting") + for _, subsystem := range h.receivers { + _ = subsystem.Stop() + } + return + case syscall.SIGQUIT: + stacklen := runtime.Stack(buf, true) + h.log.Infof("=== received SIGQUIT ===\n*** goroutine dump...\n%s\n*** end", buf[:stacklen]) + } + } + } +} + +// SignalHandlerLoop blocks until it receives a SIGINT, SIGTERM or SIGQUIT. +// For SIGINT and SIGTERM, it exits; for SIGQUIT is print a goroutine stack +// dump. +func SignalHandlerLoop(log log.Interface, ss ...SignalReceiver) { + NewHandler(log, ss...).Loop() +} diff --git a/spanlogger/spanlogger.go b/spanlogger/spanlogger.go index 7a7a4fb75..9a063e0a4 100644 --- a/spanlogger/spanlogger.go +++ b/spanlogger/spanlogger.go @@ -8,7 +8,8 @@ import ( opentracing "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" otlog "github.com/opentracing/opentracing-go/log" - "github.com/weaveworks/common/tracing" + + "github.com/grafana/dskit/tracing" ) type loggerCtxMarker struct{} diff --git a/spanlogger/spanlogger_test.go b/spanlogger/spanlogger_test.go index cffd7ca77..2ab9299e6 100644 --- a/spanlogger/spanlogger_test.go +++ b/spanlogger/spanlogger_test.go @@ -10,7 +10,8 @@ import ( "github.com/opentracing/opentracing-go/mocktracer" "github.com/pkg/errors" "github.com/stretchr/testify/require" - "github.com/weaveworks/common/user" + + "github.com/grafana/dskit/user" ) func TestSpanLogger_Log(t *testing.T) { diff --git a/tenant/resolver.go b/tenant/resolver.go index f0fd8abfe..aa19d75bb 100644 --- a/tenant/resolver.go +++ b/tenant/resolver.go @@ -6,7 +6,7 @@ import ( "net/http" "strings" - "github.com/weaveworks/common/user" + "github.com/grafana/dskit/user" ) var defaultResolver Resolver = NewSingleResolver() diff --git a/tenant/resolver_test.go b/tenant/resolver_test.go index 4d2da2416..c71c013fa 100644 --- a/tenant/resolver_test.go +++ b/tenant/resolver_test.go @@ -5,7 +5,8 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/weaveworks/common/user" + + "github.com/grafana/dskit/user" ) func strptr(s string) *string { diff --git a/tenant/tenant.go b/tenant/tenant.go index c7c772648..a5807500e 100644 --- a/tenant/tenant.go +++ b/tenant/tenant.go @@ -7,7 +7,7 @@ import ( "sort" "strings" - "github.com/weaveworks/common/user" + "github.com/grafana/dskit/user" ) var ( diff --git a/tracing/tracing.go b/tracing/tracing.go new file mode 100644 index 000000000..66b3a3cef --- /dev/null +++ b/tracing/tracing.go @@ -0,0 +1,83 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/tracing/tracing.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package tracing + +import ( + "context" + "io" + + "github.com/opentracing/opentracing-go" + "github.com/pkg/errors" + jaeger "github.com/uber/jaeger-client-go" + jaegercfg "github.com/uber/jaeger-client-go/config" + jaegerprom "github.com/uber/jaeger-lib/metrics/prometheus" +) + +// ErrInvalidConfiguration is an error to notify client to provide valid trace report agent or config server +var ( + ErrBlankTraceConfiguration = errors.New("no trace report agent, config server, or collector endpoint specified") +) + +// installJaeger registers Jaeger as the OpenTracing implementation. +func installJaeger(serviceName string, cfg *jaegercfg.Configuration, options ...jaegercfg.Option) (io.Closer, error) { + metricsFactory := jaegerprom.New() + + // put the metricsFactory earlier so provided options can override it + opts := append([]jaegercfg.Option{jaegercfg.Metrics(metricsFactory)}, options...) + + closer, err := cfg.InitGlobalTracer(serviceName, opts...) + if err != nil { + return nil, errors.Wrap(err, "could not initialize jaeger tracer") + } + return closer, nil +} + +// NewFromEnv is a convenience function to allow tracing configuration +// via environment variables +// +// Tracing will be enabled if one (or more) of the following environment variables is used to configure trace reporting: +// - JAEGER_AGENT_HOST +// - JAEGER_SAMPLER_MANAGER_HOST_PORT +func NewFromEnv(serviceName string, options ...jaegercfg.Option) (io.Closer, error) { + cfg, err := jaegercfg.FromEnv() + if err != nil { + return nil, errors.Wrap(err, "could not load jaeger tracer configuration") + } + + if cfg.Sampler.SamplingServerURL == "" && cfg.Reporter.LocalAgentHostPort == "" && cfg.Reporter.CollectorEndpoint == "" { + return nil, ErrBlankTraceConfiguration + } + + return installJaeger(serviceName, cfg, options...) +} + +// ExtractTraceID extracts the trace id, if any from the context. +func ExtractTraceID(ctx context.Context) (string, bool) { + sp := opentracing.SpanFromContext(ctx) + if sp == nil { + return "", false + } + sctx, ok := sp.Context().(jaeger.SpanContext) + if !ok { + return "", false + } + + return sctx.TraceID().String(), true +} + +// ExtractSampledTraceID works like ExtractTraceID but the returned bool is only +// true if the returned trace id is sampled. +func ExtractSampledTraceID(ctx context.Context) (string, bool) { + sp := opentracing.SpanFromContext(ctx) + if sp == nil { + return "", false + } + sctx, ok := sp.Context().(jaeger.SpanContext) + if !ok { + return "", false + } + + return sctx.TraceID().String(), sctx.IsSampled() +} diff --git a/user/grpc.go b/user/grpc.go new file mode 100644 index 000000000..201b835ee --- /dev/null +++ b/user/grpc.go @@ -0,0 +1,56 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/user/grpc.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package user + +import ( + "context" + + "google.golang.org/grpc/metadata" +) + +// ExtractFromGRPCRequest extracts the user ID from the request metadata and returns +// the user ID and a context with the user ID injected. +func ExtractFromGRPCRequest(ctx context.Context) (string, context.Context, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", ctx, ErrNoOrgID + } + + orgIDs, ok := md[lowerOrgIDHeaderName] + if !ok || len(orgIDs) != 1 { + return "", ctx, ErrNoOrgID + } + + return orgIDs[0], InjectOrgID(ctx, orgIDs[0]), nil +} + +// InjectIntoGRPCRequest injects the orgID from the context into the request metadata. +func InjectIntoGRPCRequest(ctx context.Context) (context.Context, error) { + orgID, err := ExtractOrgID(ctx) + if err != nil { + return ctx, err + } + + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + md = metadata.New(map[string]string{}) + } + newCtx := ctx + if orgIDs, ok := md[lowerOrgIDHeaderName]; ok { + if len(orgIDs) == 1 { + if orgIDs[0] != orgID { + return ctx, ErrDifferentOrgIDPresent + } + } else { + return ctx, ErrTooManyOrgIDs + } + } else { + md = md.Copy() + md[lowerOrgIDHeaderName] = []string{orgID} + newCtx = metadata.NewOutgoingContext(ctx, md) + } + + return newCtx, nil +} diff --git a/user/http.go b/user/http.go new file mode 100644 index 000000000..ca015b36d --- /dev/null +++ b/user/http.go @@ -0,0 +1,70 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/user/http.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package user + +import ( + "context" + "net/http" +) + +const ( + // 'Scope' in the below headers is a legacy from scope as a service. + + // OrgIDHeaderName denotes the OrgID the request has been authenticated as + OrgIDHeaderName = "X-Scope-OrgID" + // UserIDHeaderName denotes the UserID the request has been authenticated as + UserIDHeaderName = "X-Scope-UserID" + + // LowerOrgIDHeaderName as gRPC / HTTP2.0 headers are lowercased. + lowerOrgIDHeaderName = "x-scope-orgid" +) + +// ExtractOrgIDFromHTTPRequest extracts the org ID from the request headers and returns +// the org ID and a context with the org ID embedded. +func ExtractOrgIDFromHTTPRequest(r *http.Request) (string, context.Context, error) { + orgID := r.Header.Get(OrgIDHeaderName) + if orgID == "" { + return "", r.Context(), ErrNoOrgID + } + return orgID, InjectOrgID(r.Context(), orgID), nil +} + +// InjectOrgIDIntoHTTPRequest injects the orgID from the context into the request headers. +func InjectOrgIDIntoHTTPRequest(ctx context.Context, r *http.Request) error { + orgID, err := ExtractOrgID(ctx) + if err != nil { + return err + } + existingID := r.Header.Get(OrgIDHeaderName) + if existingID != "" && existingID != orgID { + return ErrDifferentOrgIDPresent + } + r.Header.Set(OrgIDHeaderName, orgID) + return nil +} + +// ExtractUserIDFromHTTPRequest extracts the org ID from the request headers and returns +// the org ID and a context with the org ID embedded. +func ExtractUserIDFromHTTPRequest(r *http.Request) (string, context.Context, error) { + userID := r.Header.Get(UserIDHeaderName) + if userID == "" { + return "", r.Context(), ErrNoUserID + } + return userID, InjectUserID(r.Context(), userID), nil +} + +// InjectUserIDIntoHTTPRequest injects the userID from the context into the request headers. +func InjectUserIDIntoHTTPRequest(ctx context.Context, r *http.Request) error { + userID, err := ExtractUserID(ctx) + if err != nil { + return err + } + existingID := r.Header.Get(UserIDHeaderName) + if existingID != "" && existingID != userID { + return ErrDifferentUserIDPresent + } + r.Header.Set(UserIDHeaderName, userID) + return nil +} diff --git a/user/id.go b/user/id.go new file mode 100644 index 000000000..2396787c2 --- /dev/null +++ b/user/id.go @@ -0,0 +1,58 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/user/id.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package user + +import ( + "context" + + "github.com/grafana/dskit/errors" +) + +type contextKey int + +const ( + // Keys used in contexts to find the org or user ID + orgIDContextKey contextKey = 0 + userIDContextKey contextKey = 1 +) + +// Errors that we return +const ( + ErrNoOrgID = errors.Error("no org id") + ErrDifferentOrgIDPresent = errors.Error("different org ID already present") + ErrTooManyOrgIDs = errors.Error("multiple org IDs present") + + ErrNoUserID = errors.Error("no user id") + ErrDifferentUserIDPresent = errors.Error("different user ID already present") + ErrTooManyUserIDs = errors.Error("multiple user IDs present") +) + +// ExtractOrgID gets the org ID from the context. +func ExtractOrgID(ctx context.Context) (string, error) { + orgID, ok := ctx.Value(orgIDContextKey).(string) + if !ok { + return "", ErrNoOrgID + } + return orgID, nil +} + +// InjectOrgID returns a derived context containing the org ID. +func InjectOrgID(ctx context.Context, orgID string) context.Context { + return context.WithValue(ctx, interface{}(orgIDContextKey), orgID) +} + +// ExtractUserID gets the user ID from the context. +func ExtractUserID(ctx context.Context) (string, error) { + userID, ok := ctx.Value(userIDContextKey).(string) + if !ok { + return "", ErrNoUserID + } + return userID, nil +} + +// InjectUserID returns a derived context containing the user ID. +func InjectUserID(ctx context.Context, userID string) context.Context { + return context.WithValue(ctx, interface{}(userIDContextKey), userID) +} diff --git a/user/logging.go b/user/logging.go new file mode 100644 index 000000000..0db1946c8 --- /dev/null +++ b/user/logging.go @@ -0,0 +1,26 @@ +// Provenance-includes-location: https://github.com/weaveworks/common/blob/main/user/logging.go +// Provenance-includes-license: Apache-2.0 +// Provenance-includes-copyright: Weaveworks Ltd. + +package user + +import ( + "context" + + "github.com/grafana/dskit/log" +) + +// LogWith returns user and org information from the context as log fields. +func LogWith(ctx context.Context, log log.Interface) log.Interface { + userID, err := ExtractUserID(ctx) + if err == nil { + log = log.WithField("userID", userID) + } + + orgID, err := ExtractOrgID(ctx) + if err == nil { + log = log.WithField("orgID", orgID) + } + + return log +}