Skip to content

Commit

Permalink
Merge pull request #8 from 0xsequence/fix_walking_slice
Browse files Browse the repository at this point in the history
Support walking slices, recursive types & prepare for fetching secrets in batches
  • Loading branch information
VojtechVitek authored Oct 4, 2024
2 parents 4863976 + 761d4f5 commit ade4468
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 64 deletions.
64 changes: 64 additions & 0 deletions collector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package cloudsecrets

import (
"errors"
"fmt"
"reflect"
"strings"
)

type secretField struct {
value reflect.Value
fieldPath string
secretName string
}

type collector struct {
fields []*secretField
err error
}

// Walks given reflect value recursively and collects any string fields with $SECRET: prefix.
func (g *collector) collectSecretFields(v reflect.Value, path string) {
switch v.Kind() {
case reflect.Ptr:
if v.IsNil() {
return
}

// Dereference pointer
g.collectSecretFields(v.Elem(), path)

case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
g.collectSecretFields(field, fmt.Sprintf("%v.%v", path, v.Type().Field(i).Name))
}

case reflect.Slice, reflect.Array:
for i := 0; i < v.Len(); i++ {
item := v.Index(i)
g.collectSecretFields(item, fmt.Sprintf("%v[%v]", path, i))
}

case reflect.String:
secretName, found := strings.CutPrefix(v.String(), "$SECRET:")
if !found {
return
}

if !v.CanSet() {
g.err = errors.Join(g.err, fmt.Errorf("can't set field %v", path))
return
}

g.fields = append(g.fields, &secretField{
value: v,
fieldPath: path,
secretName: secretName,
})

default:
return
}
}
87 changes: 87 additions & 0 deletions collector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package cloudsecrets

import (
"fmt"
"reflect"
"testing"
)

type dbConfig struct {
User string
Password string
}

type jwtSecret string

type config1 struct {
DB dbConfig
JWTSecrets []jwtSecret
unexported dbConfig
}

func TestCollectFields(t *testing.T) {
tt := []struct {
Input any
Out []string // field paths
Error bool
}{
{
Input: &config1{
DB: dbConfig{
User: "db-user",
Password: "db-password",
},
},
Out: []string{},
},
{
Input: &config1{
DB: dbConfig{
User: "db-user",
Password: "$SECRET:secretName",
},
JWTSecrets: []jwtSecret{"$SECRET:jwtSecret1", "$SECRET:jwtSecret2", "nope"},
},
Out: []string{"secretName", "jwtSecret1", "jwtSecret2"},
},
{
Input: &config1{
unexported: dbConfig{ // unexported fields can't be updated via reflect pkg
User: "db-user",
Password: "$SECRET:secretName", // match inside unexported field
},
},
Out: []string{},
Error: true, // expect error
},
}

for i, tc := range tt {
i, tc := i, tc
t.Run(fmt.Sprintf("tt[%v]", i), func(t *testing.T) {
v := reflect.ValueOf(tc.Input)

c := &collector{}
c.collectSecretFields(v, fmt.Sprintf("tt[%v].input", i))

if tc.Error {
if c.err == nil {
t.Error("expected error, got nil")
}
} else {
if c.err != nil {
t.Errorf("unexpected error: %v", c.err)
}
}

if len(c.fields) != len(tc.Out) {
t.Errorf("expected %v secrets, got %v", len(tc.Out), len(c.fields))
}
for i := 0; i < len(c.fields); i++ {
if c.fields[i].secretName != tc.Out[i] {
t.Errorf("collected field[%v].secretName=%v doesn't match tc.Out[%v]=%v", i, c.fields[i].secretName, i, tc.Out[i])
}
}
})
}
}
72 changes: 19 additions & 53 deletions hydrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import (
"context"
"fmt"
"reflect"
"strings"
"sync"

"github.com/0xsequence/go-cloudsecrets/gcp"
"github.com/0xsequence/go-cloudsecrets/nosecrets"
"golang.org/x/sync/errgroup"
)

// Hydrate recursively walks a given config (struct pointer) and hydrates all
Expand Down Expand Up @@ -39,74 +38,41 @@ func Hydrate(ctx context.Context, providerName string, config interface{}) error
}

v := reflect.ValueOf(config)
return hydrateStruct(ctx, provider, v)
return hydrateConfig(ctx, provider, v)
}

func hydrateStruct(ctx context.Context, provider secretsProvider, v reflect.Value) error {
func hydrateConfig(ctx context.Context, provider secretsProvider, v reflect.Value) error {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return fmt.Errorf("passed config is nil")
}

v = v.Elem()
}

if v.Kind() != reflect.Struct {
return fmt.Errorf("passed config must be struct, actual %s", v.Kind())
}

errCh := make(chan error)
wg := &sync.WaitGroup{}
hydrateStructFields(ctx, provider, v, wg, errCh)
go func() {
wg.Wait()
close(errCh)
}()

select {
case err, ok := <-errCh:
if !ok {
return nil
}
if err != nil {
return fmt.Errorf("walking struct fields: %w", err)
}
c := &collector{}
c.collectSecretFields(v, "config")
if c.err != nil {
return fmt.Errorf("failed to collect fields: %w", c.err)
}

return nil
}

func hydrateStructFields(ctx context.Context, provider secretsProvider, config reflect.Value, wg *sync.WaitGroup, errCh chan error) {
for i := 0; i < config.NumField(); i++ {
field := config.Field(i)
g := &errgroup.Group{}
for _, field := range c.fields {
field := field

if field.Kind() == reflect.Ptr {
if field.IsNil() {
continue
g.Go(func() error {
secretValue, err := provider.FetchSecret(ctx, field.secretName)
if err != nil {
return fmt.Errorf("failed to fetch secret %v=%q: %w", field.fieldPath, field.value.String(), err)
}
// Dereference pointer
field = field.Elem()
}

if field.Kind() == reflect.Struct {
hydrateStructFields(ctx, provider, field, wg, errCh)
continue
}
field.value.SetString(secretValue)

if field.Kind() == reflect.String && field.CanSet() {
secretName, found := strings.CutPrefix(field.String(), "$SECRET:")
if found {
wg.Add(1)
go func(fieldName string, field reflect.Value, secretName string) {
defer wg.Done()
secretValue, err := provider.FetchSecret(ctx, secretName)
if err != nil {
errCh <- fmt.Errorf("%v=%q: %w", fieldName, field.String(), err)
return
}
field.SetString(secretValue)
}(config.Type().Field(i).Name, field, secretName)
}
}
return nil
})
}

return g.Wait()
}
24 changes: 13 additions & 11 deletions hydrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ import (
)

type config struct {
DB db
Analytics analytics
Pass string
DB db
Analytics analytics
Pass string
JWTSecrets []string
}

type db struct {
Expand All @@ -30,13 +31,7 @@ type analytics struct {
func TestFailWhenPassedValueIsNotStruct(t *testing.T) {
input := "hello"

v := reflect.ValueOf(input)
provider := mock.NewSecretsProvider(map[string]string{
"dbPassword": "changethissecret",
"analyticsPassword": "AuthTokenSecret",
})

assert.Error(t, hydrateStruct(context.Background(), provider, v))
assert.Error(t, Hydrate(context.Background(), "", input))
}

func TestReplacePlaceholdersWithSecrets(t *testing.T) {
Expand All @@ -55,6 +50,8 @@ func TestReplacePlaceholdersWithSecrets(t *testing.T) {
"dbPassword": "changethissecret",
"analyticsPassword": "AuthTokenSecret",
"pass": "secret",
"jwtSecretV1": "some-old-secret",
"jwtSecretV2": "changeme-now",
},
conf: &config{
Pass: "$SECRET:pass",
Expand All @@ -68,6 +65,7 @@ func TestReplacePlaceholdersWithSecrets(t *testing.T) {
Server: "http://localhost:8000",
AuthToken: "$SECRET:analyticsPassword",
},
JWTSecrets: []string{"$SECRET:jwtSecretV2", "$SECRET:jwtSecretV1"},
},
wantErr: false,
wantConf: &config{
Expand All @@ -82,6 +80,10 @@ func TestReplacePlaceholdersWithSecrets(t *testing.T) {
Server: "http://localhost:8000",
AuthToken: "AuthTokenSecret",
},
JWTSecrets: []string{
"changeme-now",
"some-old-secret",
},
},
},
{
Expand Down Expand Up @@ -109,7 +111,7 @@ func TestReplacePlaceholdersWithSecrets(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := reflect.ValueOf(tt.conf)
err := hydrateStruct(ctx, mock.NewSecretsProvider(tt.storage), v)
err := hydrateConfig(ctx, mock.NewSecretsProvider(tt.storage), v)
if err != nil {
if tt.wantErr {
assert.Equal(t, tt.wantConf, tt.conf)
Expand Down

0 comments on commit ade4468

Please sign in to comment.