diff --git a/pkg/v1/remote/transport/useragent.go b/pkg/v1/remote/transport/useragent.go index 74a9e71bd..ba1a3bbea 100644 --- a/pkg/v1/remote/transport/useragent.go +++ b/pkg/v1/remote/transport/useragent.go @@ -25,12 +25,13 @@ var ( // -ldflags="-X 'github.com/google/go-containerregistry/pkg/v1/remote/transport.Version=$TAG'" Version string - ggcrVersion = defaultUserAgent + defaultUserAgent string + ggcrVersion = ggcrProduct ) const ( - defaultUserAgent = "go-containerregistry" - moduleName = "github.com/google/go-containerregistry" + ggcrProduct = "go-containerregistry" + moduleName = "github.com/google/go-containerregistry" ) type userAgentTransport struct { @@ -40,7 +41,7 @@ type userAgentTransport struct { func init() { if v := version(); v != "" { - ggcrVersion = fmt.Sprintf("%s/%s", defaultUserAgent, v) + ggcrVersion = fmt.Sprintf("%s/%s", ggcrProduct, v) } } @@ -76,6 +77,10 @@ func version() string { // // User-Agent: crane/v0.1.4 go-containerregistry/v0.1.4 func NewUserAgent(inner http.RoundTripper, ua string) http.RoundTripper { + if ua == "" { + ua = defaultUserAgent + } + // defaultUserAgent might not be set, so check this again. if ua == "" { ua = ggcrVersion } else { @@ -92,3 +97,11 @@ func (ut *userAgentTransport) RoundTrip(in *http.Request) (*http.Response, error in.Header.Set("User-Agent", ut.ua) return ut.inner.RoundTrip(in) } + +// SetDefaultUserAgent sets the global default user agent string. +// Default user agent behavior follows that of [NewUserAgent], in that the resulting +// user agent string will include both the provided user agent and the go-containerregistry +// version. +func SetDefaultUserAgent(ua string) { + defaultUserAgent = ua +} diff --git a/pkg/v1/remote/transport/useragent_test.go b/pkg/v1/remote/transport/useragent_test.go new file mode 100644 index 000000000..5ac7f9fd3 --- /dev/null +++ b/pkg/v1/remote/transport/useragent_test.go @@ -0,0 +1,44 @@ +package transport + +import ( + "testing" +) + +func TestDefaultUserAgent(t *testing.T) { + for _, tc := range []struct { + defaultUA string + ua string + want string + }{ + { + want: "go-containerregistry", + }, + { + defaultUA: "foo", + want: "foo go-containerregistry", + }, + { + ua: "bar", + want: "bar go-containerregistry", + }, + { + defaultUA: "foo", + ua: "bar", + want: "bar go-containerregistry", + }, + } { + t.Run("", func(t *testing.T) { + SetDefaultUserAgent(tc.defaultUA) + t.Cleanup(func() { + SetDefaultUserAgent("") + }) + rt, ok := NewUserAgent(nil, tc.ua).(*userAgentTransport) + if !ok { + t.Fatalf("NewUserAgent returned a %T, want *userAgentTransport", rt) + } + if rt.ua != tc.want { + t.Errorf("want %q, got %q", tc.want, rt.ua) + } + }) + } +}