From 41f79430003ed7e364707f8abab5b297cb797c7d Mon Sep 17 00:00:00 2001 From: Albert Lloveras Date: Mon, 22 Jul 2024 16:49:59 +1000 Subject: [PATCH 1/2] fixup(sg-resolver): Allow multiple SGs with the same Name tag --- pkg/algorithm/slices.go | 18 ++++ pkg/algorithm/slices_test.go | 46 +++++++++ pkg/networking/security_group_resolver.go | 46 ++++++++- .../security_group_resolver_test.go | 94 ++++++++++++++++--- 4 files changed, 188 insertions(+), 16 deletions(-) create mode 100644 pkg/algorithm/slices.go create mode 100644 pkg/algorithm/slices_test.go diff --git a/pkg/algorithm/slices.go b/pkg/algorithm/slices.go new file mode 100644 index 0000000000..a82eae5832 --- /dev/null +++ b/pkg/algorithm/slices.go @@ -0,0 +1,18 @@ +package algorithm + +import "cmp" + +// RemoveSliceDuplicates returns a copy of the slice without duplicate entries. +func RemoveSliceDuplicates[S ~[]E, E cmp.Ordered](s S) []E { + result := make([]E, 0, len(s)) + found := make(map[E]struct{}, len(s)) + + for _, x := range s { + if _, ok := found[x]; !ok { + found[x] = struct{}{} + result = append(result, x) + } + } + + return result +} diff --git a/pkg/algorithm/slices_test.go b/pkg/algorithm/slices_test.go new file mode 100644 index 0000000000..decf9deb6e --- /dev/null +++ b/pkg/algorithm/slices_test.go @@ -0,0 +1,46 @@ +package algorithm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_RemoveSliceDuplicates(t *testing.T) { + type args struct { + data []string + } + tests := []struct { + name string + args args + want []string + }{ + { + name: "empty", + args: args{ + data: []string{}, + }, + want: []string{}, + }, + { + name: "no duplicate entries", + args: args{ + data: []string{"a", "b", "c", "d"}, + }, + want: []string{"a", "b", "c", "d"}, + }, + { + name: "with duplicates", + args: args{ + data: []string{"a", "b", "a", "c", "b"}, + }, + want: []string{"a", "b", "c"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := RemoveSliceDuplicates(tt.args.data) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/networking/security_group_resolver.go b/pkg/networking/security_group_resolver.go index 402d1795f0..0243480409 100644 --- a/pkg/networking/security_group_resolver.go +++ b/pkg/networking/security_group_resolver.go @@ -7,6 +7,7 @@ import ( awssdk "github.com/aws/aws-sdk-go/aws" ec2sdk "github.com/aws/aws-sdk-go/service/ec2" "github.com/pkg/errors" + "sigs.k8s.io/aws-load-balancer-controller/pkg/algorithm" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" ) @@ -35,6 +36,7 @@ type defaultSecurityGroupResolver struct { func (r *defaultSecurityGroupResolver) ResolveViaNameOrID(ctx context.Context, sgNameOrIDs []string) ([]string, error) { sgIDs, sgNames := r.splitIntoSgNameAndIDs(sgNameOrIDs) var resolvedSGs []*ec2sdk.SecurityGroup + if len(sgIDs) > 0 { sgs, err := r.resolveViaGroupID(ctx, sgIDs) if err != nil { @@ -42,6 +44,7 @@ func (r *defaultSecurityGroupResolver) ResolveViaNameOrID(ctx context.Context, s } resolvedSGs = append(resolvedSGs, sgs...) } + if len(sgNames) > 0 { sgs, err := r.resolveViaGroupName(ctx, sgNames) if err != nil { @@ -49,13 +52,12 @@ func (r *defaultSecurityGroupResolver) ResolveViaNameOrID(ctx context.Context, s } resolvedSGs = append(resolvedSGs, sgs...) } + resolvedSGIDs := make([]string, 0, len(resolvedSGs)) for _, sg := range resolvedSGs { resolvedSGIDs = append(resolvedSGIDs, awssdk.StringValue(sg.GroupId)) } - if len(resolvedSGIDs) != len(sgNameOrIDs) { - return nil, errors.Errorf("couldn't find all securityGroups, nameOrIDs: %v, found: %v", sgNameOrIDs, resolvedSGIDs) - } + return resolvedSGIDs, nil } @@ -63,14 +65,31 @@ func (r *defaultSecurityGroupResolver) resolveViaGroupID(ctx context.Context, sg req := &ec2sdk.DescribeSecurityGroupsInput{ GroupIds: awssdk.StringSlice(sgIDs), } + sgs, err := r.ec2Client.DescribeSecurityGroupsAsList(ctx, req) if err != nil { return nil, err } + + resolvedSGIDs := make([]string, 0, len(sgs)) + for _, sg := range sgs { + resolvedSGIDs = append(resolvedSGIDs, awssdk.StringValue(sg.GroupId)) + } + + if len(sgIDs) != len(resolvedSGIDs) { + return nil, errors.Errorf( + "couldn't find all securityGroups, requested ids: [%s], found: [%s]", + strings.Join(sgIDs, ", "), + strings.Join(resolvedSGIDs, ", "), + ) + } + return sgs, nil } func (r *defaultSecurityGroupResolver) resolveViaGroupName(ctx context.Context, sgNames []string) ([]*ec2sdk.SecurityGroup, error) { + sgNames = algorithm.RemoveSliceDuplicates(sgNames) + req := &ec2sdk.DescribeSecurityGroupsInput{ Filters: []*ec2sdk.Filter{ { @@ -83,10 +102,31 @@ func (r *defaultSecurityGroupResolver) resolveViaGroupName(ctx context.Context, }, }, } + sgs, err := r.ec2Client.DescribeSecurityGroupsAsList(ctx, req) if err != nil { return nil, err } + + resolvedSGNames := make([]string, 0, len(sgs)) + for _, sg := range sgs { + for _, tag := range sg.Tags { + if awssdk.StringValue(tag.Key) == "Name" { + resolvedSGNames = append(resolvedSGNames, awssdk.StringValue(tag.Value)) + } + } + } + + resolvedSGNames = algorithm.RemoveSliceDuplicates(resolvedSGNames) + + if len(sgNames) != len(resolvedSGNames) { + return nil, errors.Errorf( + "couldn't find all securityGroups, requested names: [%s], found: [%s]", + strings.Join(sgNames, ", "), + strings.Join(resolvedSGNames, ", "), + ) + } + return sgs, nil } diff --git a/pkg/networking/security_group_resolver_test.go b/pkg/networking/security_group_resolver_test.go index ad155b75a1..2bd9dcd63f 100644 --- a/pkg/networking/security_group_resolver_test.go +++ b/pkg/networking/security_group_resolver_test.go @@ -88,9 +88,15 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { resp: []*ec2sdk.SecurityGroup{ { GroupId: awssdk.String("sg-0912f63b"), + Tags: []*ec2sdk.Tag{ + {Key: awssdk.String("Name"), Value: awssdk.String("sg group one")}, + }, }, { GroupId: awssdk.String("sg-08982de7"), + Tags: []*ec2sdk.Tag{ + {Key: awssdk.String("Name"), Value: awssdk.String("sg group two")}, + }, }, }, }, @@ -101,6 +107,50 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { "sg-0912f63b", }, }, + { + name: "single name multiple ids", + args: args{ + nameOrIDs: []string{ + "sg group one", + }, + describeSGCalls: []describeSecurityGroupsAsListCall{ + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + Filters: []*ec2sdk.Filter{ + { + Name: awssdk.String("tag:Name"), + Values: awssdk.StringSlice([]string{ + "sg group one", + }), + }, + { + Name: awssdk.String("vpc-id"), + Values: awssdk.StringSlice([]string{defaultVPCID}), + }, + }, + }, + resp: []*ec2sdk.SecurityGroup{ + { + GroupId: awssdk.String("sg-id1"), + Tags: []*ec2sdk.Tag{ + {Key: awssdk.String("Name"), Value: awssdk.String("sg group one")}, + }, + }, + { + GroupId: awssdk.String("sg-id2"), + Tags: []*ec2sdk.Tag{ + {Key: awssdk.String("Name"), Value: awssdk.String("sg group one")}, + }, + }, + }, + }, + }, + }, + want: []string{ + "sg-id1", + "sg-id2", + }, + }, { name: "mixed group name and id", args: args{ @@ -127,6 +177,9 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { resp: []*ec2sdk.SecurityGroup{ { GroupId: awssdk.String("sg-0912f63b"), + Tags: []*ec2sdk.Tag{ + {Key: awssdk.String("Name"), Value: awssdk.String("sg group one")}, + }, }, }, }, @@ -205,13 +258,34 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { wantErr: errors.New("Describe.Error: unable to describe security groups"), }, { - name: "unable to resolve all security groups", + name: "unable to resolve all security group ids", args: args{ nameOrIDs: []string{ - "sg group one", "sg-id1", "sg-id404", }, + describeSGCalls: []describeSecurityGroupsAsListCall{ + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + GroupIds: awssdk.StringSlice([]string{"sg-id1", "sg-id404"}), + }, + resp: []*ec2sdk.SecurityGroup{ + { + GroupId: awssdk.String("sg-id1"), + }, + }, + }, + }, + }, + wantErr: errors.New("couldn't find all securityGroups, requested ids: [sg-id1, sg-id404], found: [sg-id1]"), + }, + { + name: "unable to resolve all security groups names", + args: args{ + nameOrIDs: []string{ + "sg group one", + "sg group two", + }, describeSGCalls: []describeSecurityGroupsAsListCall{ { req: &ec2sdk.DescribeSecurityGroupsInput{ @@ -220,6 +294,7 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { Name: awssdk.String("tag:Name"), Values: awssdk.StringSlice([]string{ "sg group one", + "sg group two", }), }, { @@ -231,22 +306,15 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { resp: []*ec2sdk.SecurityGroup{ { GroupId: awssdk.String("sg-0912f63b"), - }, - }, - }, - { - req: &ec2sdk.DescribeSecurityGroupsInput{ - GroupIds: awssdk.StringSlice([]string{"sg-id1", "sg-id404"}), - }, - resp: []*ec2sdk.SecurityGroup{ - { - GroupId: awssdk.String("sg-id1"), + Tags: []*ec2sdk.Tag{ + {Key: awssdk.String("Name"), Value: awssdk.String("sg group one")}, + }, }, }, }, }, }, - wantErr: errors.New("couldn't find all securityGroups, nameOrIDs: [sg group one sg-id1 sg-id404], found: [sg-id1 sg-0912f63b]"), + wantErr: errors.New("couldn't find all securityGroups, requested names: [sg group one, sg group two], found: [sg group one]"), }, } From 12a7953a3112c0478a3f5178bf16b5650910e14d Mon Sep 17 00:00:00 2001 From: Albert Lloveras Date: Fri, 16 Aug 2024 13:51:58 +1000 Subject: [PATCH 2/2] PR Feedback: Re-work error reporting and tests --- pkg/networking/security_group_resolver.go | 30 ++++----- .../security_group_resolver_test.go | 63 ++++++++++++------- 2 files changed, 56 insertions(+), 37 deletions(-) diff --git a/pkg/networking/security_group_resolver.go b/pkg/networking/security_group_resolver.go index 0243480409..3b807729d6 100644 --- a/pkg/networking/security_group_resolver.go +++ b/pkg/networking/security_group_resolver.go @@ -34,23 +34,31 @@ type defaultSecurityGroupResolver struct { } func (r *defaultSecurityGroupResolver) ResolveViaNameOrID(ctx context.Context, sgNameOrIDs []string) ([]string, error) { - sgIDs, sgNames := r.splitIntoSgNameAndIDs(sgNameOrIDs) var resolvedSGs []*ec2sdk.SecurityGroup + var errMessages []string + + sgIDs, sgNames := r.splitIntoSgNameAndIDs(sgNameOrIDs) if len(sgIDs) > 0 { sgs, err := r.resolveViaGroupID(ctx, sgIDs) if err != nil { - return nil, err + errMessages = append(errMessages, err.Error()) + } else { + resolvedSGs = append(resolvedSGs, sgs...) } - resolvedSGs = append(resolvedSGs, sgs...) } if len(sgNames) > 0 { sgs, err := r.resolveViaGroupName(ctx, sgNames) if err != nil { - return nil, err + errMessages = append(errMessages, err.Error()) + } else { + resolvedSGs = append(resolvedSGs, sgs...) } - resolvedSGs = append(resolvedSGs, sgs...) + } + + if len(errMessages) > 0 { + return nil, errors.Errorf("couldn't find all security groups: %s", strings.Join(errMessages, ", ")) } resolvedSGIDs := make([]string, 0, len(resolvedSGs)) @@ -77,11 +85,7 @@ func (r *defaultSecurityGroupResolver) resolveViaGroupID(ctx context.Context, sg } if len(sgIDs) != len(resolvedSGIDs) { - return nil, errors.Errorf( - "couldn't find all securityGroups, requested ids: [%s], found: [%s]", - strings.Join(sgIDs, ", "), - strings.Join(resolvedSGIDs, ", "), - ) + return nil, errors.Errorf("requested ids [%s] but found [%s]", strings.Join(sgIDs, ", "), strings.Join(resolvedSGIDs, ", ")) } return sgs, nil @@ -120,11 +124,7 @@ func (r *defaultSecurityGroupResolver) resolveViaGroupName(ctx context.Context, resolvedSGNames = algorithm.RemoveSliceDuplicates(resolvedSGNames) if len(sgNames) != len(resolvedSGNames) { - return nil, errors.Errorf( - "couldn't find all securityGroups, requested names: [%s], found: [%s]", - strings.Join(sgNames, ", "), - strings.Join(resolvedSGNames, ", "), - ) + return nil, errors.Errorf("requested names [%s] but found [%s]", strings.Join(sgNames, ", "), strings.Join(resolvedSGNames, ", ")) } return sgs, nil diff --git a/pkg/networking/security_group_resolver_test.go b/pkg/networking/security_group_resolver_test.go index 2bd9dcd63f..bd2c663c50 100644 --- a/pkg/networking/security_group_resolver_test.go +++ b/pkg/networking/security_group_resolver_test.go @@ -204,7 +204,6 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { name: "describe by id returns error", args: args{ nameOrIDs: []string{ - "sg group name", "sg-id", }, describeSGCalls: []describeSecurityGroupsAsListCall{ @@ -216,24 +215,21 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { }, }, }, - wantErr: errors.New("Describe.Error: unable to describe security groups"), + wantErr: errors.New("couldn't find all security groups: Describe.Error: unable to describe security groups"), }, { name: "describe by name returns error", args: args{ nameOrIDs: []string{ "sg group name", - "sg-id", }, describeSGCalls: []describeSecurityGroupsAsListCall{ { req: &ec2sdk.DescribeSecurityGroupsInput{ Filters: []*ec2sdk.Filter{ { - Name: awssdk.String("tag:Name"), - Values: awssdk.StringSlice([]string{ - "sg group name", - }), + Name: awssdk.String("tag:Name"), + Values: awssdk.StringSlice([]string{"sg group name"}), }, { Name: awssdk.String("vpc-id"), @@ -243,22 +239,12 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { }, err: awserr.New("Describe.Error", "unable to describe security groups", nil), }, - { - req: &ec2sdk.DescribeSecurityGroupsInput{ - GroupIds: awssdk.StringSlice([]string{"sg-id"}), - }, - resp: []*ec2sdk.SecurityGroup{ - { - GroupId: awssdk.String("sg-id"), - }, - }, - }, }, }, - wantErr: errors.New("Describe.Error: unable to describe security groups"), + wantErr: errors.New("couldn't find all security groups: Describe.Error: unable to describe security groups"), }, { - name: "unable to resolve all security group ids", + name: "unable to resolve security groups by id", args: args{ nameOrIDs: []string{ "sg-id1", @@ -277,10 +263,10 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { }, }, }, - wantErr: errors.New("couldn't find all securityGroups, requested ids: [sg-id1, sg-id404], found: [sg-id1]"), + wantErr: errors.New("couldn't find all security groups: requested ids [sg-id1, sg-id404] but found [sg-id1]"), }, { - name: "unable to resolve all security groups names", + name: "unable to resolve security groups by name", args: args{ nameOrIDs: []string{ "sg group one", @@ -314,7 +300,40 @@ func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { }, }, }, - wantErr: errors.New("couldn't find all securityGroups, requested names: [sg group one, sg group two], found: [sg group one]"), + wantErr: errors.New("couldn't find all security groups: requested names [sg group one, sg group two] but found [sg group one]"), + }, + { + name: "unable to resolve all security groups by ids and names", + args: args{ + nameOrIDs: []string{ + "sg-08982de7", + "sg group one", + }, + describeSGCalls: []describeSecurityGroupsAsListCall{ + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + GroupIds: awssdk.StringSlice([]string{"sg-08982de7"}), + }, + resp: []*ec2sdk.SecurityGroup{}, + }, + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + Filters: []*ec2sdk.Filter{ + { + Name: awssdk.String("tag:Name"), + Values: awssdk.StringSlice([]string{"sg group one"}), + }, + { + Name: awssdk.String("vpc-id"), + Values: awssdk.StringSlice([]string{defaultVPCID}), + }, + }, + }, + resp: []*ec2sdk.SecurityGroup{}, + }, + }, + }, + wantErr: errors.New("couldn't find all security groups: requested ids [sg-08982de7] but found [], requested names [sg group one] but found []"), }, }