From 3414ac81c858396f4388421504e9c02c9bb5b5fa Mon Sep 17 00:00:00 2001 From: Jakob Beckmann Date: Fri, 7 Oct 2022 17:06:37 +0200 Subject: [PATCH] fix(#163): change update (POST) request on roles to correctly use defaults fields if not provided, and add PATCH for role path --- path_role.go | 158 ++++++++++++++++++++++++++++++++-------- path_role_test.go | 179 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 307 insertions(+), 30 deletions(-) diff --git a/path_role.go b/path_role.go index 52918905..da38a108 100644 --- a/path_role.go +++ b/path_role.go @@ -94,6 +94,7 @@ default: %q Callbacks: map[logical.Operation]framework.OperationFunc{ logical.CreateOperation: b.pathRoleCreateUpdate, logical.UpdateOperation: b.pathRoleCreateUpdate, + logical.PatchOperation: b.pathRolePatch, logical.ReadOperation: b.pathRoleRead, logical.DeleteOperation: b.pathRoleDelete, }, @@ -220,17 +221,96 @@ func (b *kubeAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical b.l.Lock() defer b.l.Unlock() + role := &roleStorageEntry{} + + if err := role.ParseTokenFields(req, data); err != nil { + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + } + + // Handle upgrade cases + { + if err := tokenutil.UpgradeValue(data, "policies", "token_policies", &role.Policies, &role.TokenPolicies); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + if err := tokenutil.UpgradeValue(data, "bound_cidrs", "token_bound_cidrs", &role.BoundCIDRs, &role.TokenBoundCIDRs); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + if err := tokenutil.UpgradeValue(data, "num_uses", "token_num_uses", &role.NumUses, &role.TokenNumUses); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + if err := tokenutil.UpgradeValue(data, "ttl", "token_ttl", &role.TTL, &role.TokenTTL); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + if err := tokenutil.UpgradeValue(data, "max_ttl", "token_max_ttl", &role.MaxTTL, &role.TokenMaxTTL); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + if err := tokenutil.UpgradeValue(data, "period", "token_period", &role.Period, &role.TokenPeriod); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + } + + if err := role.validateTokenLifetimes(b); err != nil { + return err, nil + } + + var resp *logical.Response + if role.TokenMaxTTL > b.System().MaxLeaseTTL() { + resp = &logical.Response{} + resp.AddWarning("max_ttl is greater than the system or backend mount's maximum TTL value; issued tokens' max TTL value will be truncated") + } + + role.ServiceAccountNames = data.Get("bound_service_account_names").([]string) + role.ServiceAccountNamespaces = data.Get("bound_service_account_namespaces").([]string) + + if err := role.validateServiceAccountMetadata(); err != nil { + return err, nil + } + + role.Audience = data.Get("audience").(string) + role.AliasNameSource = data.Get("alias_name_source").(string) + + if err := validateAliasNameSource(role.AliasNameSource); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + // Store the entry. + entry, err := logical.StorageEntryJSON("role/"+strings.ToLower(roleName), role) + if err != nil { + return nil, err + } + if entry == nil { + return nil, fmt.Errorf("failed to create storage entry for role %s", roleName) + } + if err = req.Storage.Put(ctx, entry); err != nil { + return nil, err + } + + return resp, nil +} + +// pathRolePatch patches an existing role with provided options +func (b *kubeAuthBackend) pathRolePatch(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + roleName := data.Get("name").(string) + if roleName == "" { + return logical.ErrorResponse("missing role name"), nil + } + + b.l.Lock() + defer b.l.Unlock() + // Check if the role already exists role, err := b.role(ctx, req.Storage, roleName) if err != nil { return nil, err } - // Create a new entry object if this is a CreateOperation - if role == nil && req.Operation == logical.CreateOperation { - role = &roleStorageEntry{} - } else if role == nil { - return nil, fmt.Errorf("role entry not found during update operation") + if role == nil { + return logical.ErrorResponse("Unable to fetch role entry to patch"), nil } if err := role.ParseTokenFields(req, data); err != nil { @@ -264,15 +344,8 @@ func (b *kubeAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical } } - if role.TokenPeriod > b.System().MaxLeaseTTL() { - return logical.ErrorResponse(fmt.Sprintf("token period of '%q' is greater than the backend's maximum lease TTL of '%q'", role.TokenPeriod.String(), b.System().MaxLeaseTTL().String())), nil - } - - // Check that the TTL value provided is less than the MaxTTL. - // Sanitizing the TTL and MaxTTL is not required now and can be performed - // at credential issue time. - if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL { - return logical.ErrorResponse("token ttl should not be greater than token max ttl"), nil + if err := role.validateTokenLifetimes(b); err != nil { + return err, nil } var resp *logical.Response @@ -286,27 +359,15 @@ func (b *kubeAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical } else if req.Operation == logical.CreateOperation { role.ServiceAccountNames = data.Get("bound_service_account_names").([]string) } - // Verify names was not empty - if len(role.ServiceAccountNames) == 0 { - return logical.ErrorResponse("%q can not be empty", "bound_service_account_names"), nil - } - // Verify * was not set with other data - if len(role.ServiceAccountNames) > 1 && strutil.StrListContains(role.ServiceAccountNames, "*") { - return logical.ErrorResponse("can not mix %q with values", "*"), nil - } if namespaces, ok := data.GetOk("bound_service_account_namespaces"); ok { role.ServiceAccountNamespaces = namespaces.([]string) } else if req.Operation == logical.CreateOperation { role.ServiceAccountNamespaces = data.Get("bound_service_account_namespaces").([]string) } - // Verify namespaces is not empty - if len(role.ServiceAccountNamespaces) == 0 { - return logical.ErrorResponse("%q can not be empty", "bound_service_account_namespaces"), nil - } - // Verify * was not set with other data - if len(role.ServiceAccountNamespaces) > 1 && strutil.StrListContains(role.ServiceAccountNamespaces, "*") { - return logical.ErrorResponse("can not mix %q with values", "*"), nil + + if err := role.validateServiceAccountMetadata(); err != nil { + return err, nil } // optional audience field @@ -345,6 +406,45 @@ func (b *kubeAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical return resp, nil } +func (role *roleStorageEntry) validateTokenLifetimes(backend *kubeAuthBackend) *logical.Response { + if role.TokenPeriod > backend.System().MaxLeaseTTL() { + return logical.ErrorResponse( + fmt.Sprintf( + "token period of '%q' is greater than the backend's maximum lease TTL of '%q'", + role.TokenPeriod.String(), + backend.System().MaxLeaseTTL().String(), + ), + ) + } + + // Check that the TTL value provided is less than the MaxTTL. + // Sanitizing the TTL and MaxTTL is not required now and can be performed + // at credential issue time. + if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL { + return logical.ErrorResponse("token ttl should not be greater than token max ttl") + } + return nil +} + +func (role *roleStorageEntry) validateServiceAccountMetadata() *logical.Response { + if len(role.ServiceAccountNames) == 0 { + return logical.ErrorResponse("%q can not be empty", "bound_service_account_names") + } + // Verify * was not set with other data + if len(role.ServiceAccountNames) > 1 && strutil.StrListContains(role.ServiceAccountNames, "*") { + return logical.ErrorResponse("can not mix %q with values", "*") + } + // Verify namespaces is not empty + if len(role.ServiceAccountNamespaces) == 0 { + return logical.ErrorResponse("%q can not be empty", "bound_service_account_namespaces") + } + // Verify * was not set with other data + if len(role.ServiceAccountNamespaces) > 1 && strutil.StrListContains(role.ServiceAccountNamespaces, "*") { + return logical.ErrorResponse("can not mix %q with values", "*") + } + return nil +} + // roleStorageEntry stores all the options that are set on an role type roleStorageEntry struct { tokenutil.TokenParams diff --git a/path_role_test.go b/path_role_test.go index 3b63113c..ab02084f 100644 --- a/path_role_test.go +++ b/path_role_test.go @@ -15,6 +15,10 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) +var ( + errBoundNamespacesEmpty = fmt.Errorf(`"bound_service_account_namespaces" can not be empty`) +) + func getBackend(t *testing.T) (logical.Backend, logical.Storage) { defaultLeaseTTLVal := time.Hour * 12 maxLeaseTTLVal := time.Hour * 24 @@ -314,6 +318,179 @@ func TestPath_Delete(t *testing.T) { } func TestPath_Update(t *testing.T) { + testCases := map[string]struct { + storageData map[string]interface{} + requestData map[string]interface{} + expected *roleStorageEntry + wantErr error + }{ + "default": { + storageData: map[string]interface{}{ + "bound_service_account_names": []string{"name"}, + "bound_service_account_namespaces": []string{"namespace"}, + "policies": []string{"test"}, + "period": 1 * time.Second, + "ttl": 1 * time.Second, + "num_uses": 12, + "max_ttl": 5 * time.Second, + "alias_name_source": aliasNameSourceDefault, + }, + requestData: map[string]interface{}{ + "bound_service_account_names": []string{"name"}, + "bound_service_account_namespaces": []string{"namespace"}, + "alias_name_source": aliasNameSourceDefault, + "policies": []string{"bar", "foo"}, + "period": "3s", + "ttl": "1s", + "num_uses": 12, + "max_ttl": "5s", + }, + expected: &roleStorageEntry{ + TokenParams: tokenutil.TokenParams{ + TokenPolicies: []string{"bar", "foo"}, + TokenPeriod: 3 * time.Second, + TokenTTL: 1 * time.Second, + TokenMaxTTL: 5 * time.Second, + TokenNumUses: 12, + TokenBoundCIDRs: nil, + }, + Policies: []string{"bar", "foo"}, + Period: 3 * time.Second, + ServiceAccountNames: []string{"name"}, + ServiceAccountNamespaces: []string{"namespace"}, + TTL: 1 * time.Second, + MaxTTL: 5 * time.Second, + NumUses: 12, + BoundCIDRs: nil, + AliasNameSource: aliasNameSourceDefault, + }, + wantErr: nil, + }, + "use-defaults-on-missing": { + storageData: map[string]interface{}{ + "bound_service_account_names": []string{"name"}, + "bound_service_account_namespaces": []string{"namespace"}, + "policies": []string{"test"}, + "period": 1 * time.Second, + "ttl": 1 * time.Second, + "num_uses": 12, + "max_ttl": 5 * time.Second, + "alias_name_source": aliasNameSourceSAName, + }, + requestData: map[string]interface{}{ + "bound_service_account_names": []string{"name"}, + "bound_service_account_namespaces": []string{"namespace"}, + "policies": []string{"bar", "foo"}, + "period": "3s", + "ttl": "1s", + "num_uses": 12, + "max_ttl": "5s", + }, + expected: &roleStorageEntry{ + TokenParams: tokenutil.TokenParams{ + TokenPolicies: []string{"bar", "foo"}, + TokenPeriod: 3 * time.Second, + TokenTTL: 1 * time.Second, + TokenMaxTTL: 5 * time.Second, + TokenNumUses: 12, + TokenBoundCIDRs: nil, + }, + Policies: []string{"bar", "foo"}, + Period: 3 * time.Second, + ServiceAccountNames: []string{"name"}, + ServiceAccountNamespaces: []string{"namespace"}, + TTL: 1 * time.Second, + MaxTTL: 5 * time.Second, + NumUses: 12, + BoundCIDRs: nil, + AliasNameSource: aliasNameSourceDefault, + }, + wantErr: nil, + }, + "missing-required-data": { + storageData: map[string]interface{}{ + "bound_service_account_names": []string{"name"}, + "bound_service_account_namespaces": []string{"namespace"}, + "policies": []string{"test"}, + "period": 1 * time.Second, + "ttl": 1 * time.Second, + "num_uses": 12, + "max_ttl": 5 * time.Second, + "alias_name_source": aliasNameSourceDefault, + }, + requestData: map[string]interface{}{ + "bound_service_account_names": []string{"name"}, + "alias_name_source": aliasNameSourceDefault, + "policies": []string{"bar", "foo"}, + "period": "3s", + "ttl": "1s", + "num_uses": 12, + "max_ttl": "5s", + }, + wantErr: errBoundNamespacesEmpty, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + b, storage := getBackend(t) + path := fmt.Sprintf("role/%s", name) + + data, err := json.Marshal(tc.storageData) + if err != nil { + t.Fatal(err) + } + + entry := &logical.StorageEntry{ + Key: path, + Value: data, + SealWrap: false, + } + if err := storage.Put(context.Background(), entry); err != nil { + t.Fatal(err) + } + + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: path, + Storage: storage, + Data: tc.requestData, + } + + resp, err := b.HandleRequest(context.Background(), req) + + if tc.wantErr != nil { + var actual error + if err != nil { + actual = err + } else if resp != nil && resp.IsError() { + actual = resp.Error() + } else { + t.Fatalf("expected error") + } + + if tc.wantErr.Error() != actual.Error() { + t.Fatalf("expected err %q, actual %q", tc.wantErr, actual) + } + } else { + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + actual, err := b.(*kubeAuthBackend).role(context.Background(), storage, name) + if err != nil { + t.Fatal(err) + } + + if diff := deep.Equal(tc.expected, actual); diff != nil { + t.Fatal(diff) + } + } + }) + } +} + +func TestPath_Patch(t *testing.T) { testCases := map[string]struct { storageData map[string]interface{} requestData map[string]interface{} @@ -444,7 +621,7 @@ func TestPath_Update(t *testing.T) { } req := &logical.Request{ - Operation: logical.UpdateOperation, + Operation: logical.PatchOperation, Path: path, Storage: storage, Data: tc.requestData,