From 35ed7a06e635c9eb2a2c03c41bdfa215fa4d3e44 Mon Sep 17 00:00:00 2001 From: Anis Elleuch Date: Wed, 6 Dec 2023 20:04:30 -0800 Subject: [PATCH] tier: Add support of service principal to Azure --- .github/workflows/vulncheck.yml | 2 +- tier-azure.go | 23 +++ tier-azure_gen.go | 293 +++++++++++++++++++++++++++++++- tier-azure_gen_test.go | 113 ++++++++++++ 4 files changed, 425 insertions(+), 6 deletions(-) diff --git a/.github/workflows/vulncheck.yml b/.github/workflows/vulncheck.yml index 3ae0b85..4c8ee2b 100644 --- a/.github/workflows/vulncheck.yml +++ b/.github/workflows/vulncheck.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go-version: [ 1.21.4 ] + go-version: [ 1.21.5 ] steps: - name: Check out code into the Go module directory uses: actions/checkout@v3 diff --git a/tier-azure.go b/tier-azure.go index 80c17a1..92f9b7a 100644 --- a/tier-azure.go +++ b/tier-azure.go @@ -21,6 +21,12 @@ package madmin //go:generate msgp -file $GOFILE +type ServicePrincipalAuth struct { + TenantID string `json:",omitempty"` + ClientID string `json:",omitempty"` + ClientSecret string `json:",omitempty"` +} + // TierAzure represents the remote tier configuration for Azure Blob Storage. type TierAzure struct { Endpoint string `json:",omitempty"` @@ -30,11 +36,28 @@ type TierAzure struct { Prefix string `json:",omitempty"` Region string `json:",omitempty"` StorageClass string `json:",omitempty"` + + SPAuth ServicePrincipalAuth `json:",omitempty"` +} + +// IsSPEnabled() returns true if SP related fields are provided +func (ti TierAzure) IsSPEnabled() bool { + return ti.SPAuth.TenantID != "" || ti.SPAuth.ClientID != "" || ti.SPAuth.ClientSecret != "" } // AzureOptions supports NewTierAzure to take variadic options type AzureOptions func(*TierAzure) error +// AzureServicePrincipal helper to supply optional service principal credentials +func AzureServicePrincipal(tenantID, clientID, clientSecret string) func(az *TierAzure) error { + return func(az *TierAzure) error { + az.SPAuth.TenantID = tenantID + az.SPAuth.ClientID = clientID + az.SPAuth.ClientSecret = clientSecret + return nil + } +} + // AzurePrefix helper to supply optional object prefix to NewTierAzure func AzurePrefix(prefix string) func(az *TierAzure) error { return func(az *TierAzure) error { diff --git a/tier-azure_gen.go b/tier-azure_gen.go index 95a50f2..d759a02 100644 --- a/tier-azure_gen.go +++ b/tier-azure_gen.go @@ -6,6 +6,159 @@ import ( "github.com/tinylib/msgp/msgp" ) +// DecodeMsg implements msgp.Decodable +func (z *ServicePrincipalAuth) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "TenantID": + z.TenantID, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "TenantID") + return + } + case "ClientID": + z.ClientID, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "ClientID") + return + } + case "ClientSecret": + z.ClientSecret, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "ClientSecret") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z ServicePrincipalAuth) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 3 + // write "TenantID" + err = en.Append(0x83, 0xa8, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x49, 0x44) + if err != nil { + return + } + err = en.WriteString(z.TenantID) + if err != nil { + err = msgp.WrapError(err, "TenantID") + return + } + // write "ClientID" + err = en.Append(0xa8, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44) + if err != nil { + return + } + err = en.WriteString(z.ClientID) + if err != nil { + err = msgp.WrapError(err, "ClientID") + return + } + // write "ClientSecret" + err = en.Append(0xac, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74) + if err != nil { + return + } + err = en.WriteString(z.ClientSecret) + if err != nil { + err = msgp.WrapError(err, "ClientSecret") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z ServicePrincipalAuth) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 3 + // string "TenantID" + o = append(o, 0x83, 0xa8, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x49, 0x44) + o = msgp.AppendString(o, z.TenantID) + // string "ClientID" + o = append(o, 0xa8, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44) + o = msgp.AppendString(o, z.ClientID) + // string "ClientSecret" + o = append(o, 0xac, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74) + o = msgp.AppendString(o, z.ClientSecret) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *ServicePrincipalAuth) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "TenantID": + z.TenantID, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "TenantID") + return + } + case "ClientID": + z.ClientID, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "ClientID") + return + } + case "ClientSecret": + z.ClientSecret, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "ClientSecret") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z ServicePrincipalAuth) Msgsize() (s int) { + s = 1 + 9 + msgp.StringPrefixSize + len(z.TenantID) + 9 + msgp.StringPrefixSize + len(z.ClientID) + 13 + msgp.StringPrefixSize + len(z.ClientSecret) + return +} + // DecodeMsg implements msgp.Decodable func (z *TierAzure) DecodeMsg(dc *msgp.Reader) (err error) { var field []byte @@ -66,6 +219,47 @@ func (z *TierAzure) DecodeMsg(dc *msgp.Reader) (err error) { err = msgp.WrapError(err, "StorageClass") return } + case "SPAuth": + var zb0002 uint32 + zb0002, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err, "SPAuth") + return + } + for zb0002 > 0 { + zb0002-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err, "SPAuth") + return + } + switch msgp.UnsafeString(field) { + case "TenantID": + z.SPAuth.TenantID, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "SPAuth", "TenantID") + return + } + case "ClientID": + z.SPAuth.ClientID, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "SPAuth", "ClientID") + return + } + case "ClientSecret": + z.SPAuth.ClientSecret, err = dc.ReadString() + if err != nil { + err = msgp.WrapError(err, "SPAuth", "ClientSecret") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err, "SPAuth") + return + } + } + } default: err = dc.Skip() if err != nil { @@ -79,9 +273,9 @@ func (z *TierAzure) DecodeMsg(dc *msgp.Reader) (err error) { // EncodeMsg implements msgp.Encodable func (z *TierAzure) EncodeMsg(en *msgp.Writer) (err error) { - // map header, size 7 + // map header, size 8 // write "Endpoint" - err = en.Append(0x87, 0xa8, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74) + err = en.Append(0x88, 0xa8, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74) if err != nil { return } @@ -150,15 +344,51 @@ func (z *TierAzure) EncodeMsg(en *msgp.Writer) (err error) { err = msgp.WrapError(err, "StorageClass") return } + // write "SPAuth" + err = en.Append(0xa6, 0x53, 0x50, 0x41, 0x75, 0x74, 0x68) + if err != nil { + return + } + // map header, size 3 + // write "TenantID" + err = en.Append(0x83, 0xa8, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x49, 0x44) + if err != nil { + return + } + err = en.WriteString(z.SPAuth.TenantID) + if err != nil { + err = msgp.WrapError(err, "SPAuth", "TenantID") + return + } + // write "ClientID" + err = en.Append(0xa8, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44) + if err != nil { + return + } + err = en.WriteString(z.SPAuth.ClientID) + if err != nil { + err = msgp.WrapError(err, "SPAuth", "ClientID") + return + } + // write "ClientSecret" + err = en.Append(0xac, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74) + if err != nil { + return + } + err = en.WriteString(z.SPAuth.ClientSecret) + if err != nil { + err = msgp.WrapError(err, "SPAuth", "ClientSecret") + return + } return } // MarshalMsg implements msgp.Marshaler func (z *TierAzure) MarshalMsg(b []byte) (o []byte, err error) { o = msgp.Require(b, z.Msgsize()) - // map header, size 7 + // map header, size 8 // string "Endpoint" - o = append(o, 0x87, 0xa8, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74) + o = append(o, 0x88, 0xa8, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74) o = msgp.AppendString(o, z.Endpoint) // string "AccountName" o = append(o, 0xab, 0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x4e, 0x61, 0x6d, 0x65) @@ -178,6 +408,18 @@ func (z *TierAzure) MarshalMsg(b []byte) (o []byte, err error) { // string "StorageClass" o = append(o, 0xac, 0x53, 0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x43, 0x6c, 0x61, 0x73, 0x73) o = msgp.AppendString(o, z.StorageClass) + // string "SPAuth" + o = append(o, 0xa6, 0x53, 0x50, 0x41, 0x75, 0x74, 0x68) + // map header, size 3 + // string "TenantID" + o = append(o, 0x83, 0xa8, 0x54, 0x65, 0x6e, 0x61, 0x6e, 0x74, 0x49, 0x44) + o = msgp.AppendString(o, z.SPAuth.TenantID) + // string "ClientID" + o = append(o, 0xa8, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44) + o = msgp.AppendString(o, z.SPAuth.ClientID) + // string "ClientSecret" + o = append(o, 0xac, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74) + o = msgp.AppendString(o, z.SPAuth.ClientSecret) return } @@ -241,6 +483,47 @@ func (z *TierAzure) UnmarshalMsg(bts []byte) (o []byte, err error) { err = msgp.WrapError(err, "StorageClass") return } + case "SPAuth": + var zb0002 uint32 + zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err, "SPAuth") + return + } + for zb0002 > 0 { + zb0002-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err, "SPAuth") + return + } + switch msgp.UnsafeString(field) { + case "TenantID": + z.SPAuth.TenantID, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "SPAuth", "TenantID") + return + } + case "ClientID": + z.SPAuth.ClientID, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "SPAuth", "ClientID") + return + } + case "ClientSecret": + z.SPAuth.ClientSecret, bts, err = msgp.ReadStringBytes(bts) + if err != nil { + err = msgp.WrapError(err, "SPAuth", "ClientSecret") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err, "SPAuth") + return + } + } + } default: bts, err = msgp.Skip(bts) if err != nil { @@ -255,6 +538,6 @@ func (z *TierAzure) UnmarshalMsg(bts []byte) (o []byte, err error) { // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message func (z *TierAzure) Msgsize() (s int) { - s = 1 + 9 + msgp.StringPrefixSize + len(z.Endpoint) + 12 + msgp.StringPrefixSize + len(z.AccountName) + 11 + msgp.StringPrefixSize + len(z.AccountKey) + 7 + msgp.StringPrefixSize + len(z.Bucket) + 7 + msgp.StringPrefixSize + len(z.Prefix) + 7 + msgp.StringPrefixSize + len(z.Region) + 13 + msgp.StringPrefixSize + len(z.StorageClass) + s = 1 + 9 + msgp.StringPrefixSize + len(z.Endpoint) + 12 + msgp.StringPrefixSize + len(z.AccountName) + 11 + msgp.StringPrefixSize + len(z.AccountKey) + 7 + msgp.StringPrefixSize + len(z.Bucket) + 7 + msgp.StringPrefixSize + len(z.Prefix) + 7 + msgp.StringPrefixSize + len(z.Region) + 13 + msgp.StringPrefixSize + len(z.StorageClass) + 7 + 1 + 9 + msgp.StringPrefixSize + len(z.SPAuth.TenantID) + 9 + msgp.StringPrefixSize + len(z.SPAuth.ClientID) + 13 + msgp.StringPrefixSize + len(z.SPAuth.ClientSecret) return } diff --git a/tier-azure_gen_test.go b/tier-azure_gen_test.go index 45ddc67..6c13b77 100644 --- a/tier-azure_gen_test.go +++ b/tier-azure_gen_test.go @@ -9,6 +9,119 @@ import ( "github.com/tinylib/msgp/msgp" ) +func TestMarshalUnmarshalServicePrincipalAuth(t *testing.T) { + v := ServicePrincipalAuth{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgServicePrincipalAuth(b *testing.B) { + v := ServicePrincipalAuth{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgServicePrincipalAuth(b *testing.B) { + v := ServicePrincipalAuth{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalServicePrincipalAuth(b *testing.B) { + v := ServicePrincipalAuth{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodeServicePrincipalAuth(t *testing.T) { + v := ServicePrincipalAuth{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodeServicePrincipalAuth Msgsize() is inaccurate") + } + + vn := ServicePrincipalAuth{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodeServicePrincipalAuth(b *testing.B) { + v := ServicePrincipalAuth{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodeServicePrincipalAuth(b *testing.B) { + v := ServicePrincipalAuth{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + func TestMarshalUnmarshalTierAzure(t *testing.T) { v := TierAzure{} bts, err := v.MarshalMsg(nil)