Skip to content

Commit

Permalink
Update awsconfig
Browse files Browse the repository at this point in the history
* Add a Cache for caching credentials, similar to SDK v1 session cache.
* Add a Provider interface that provides aws.Config
* Simplified role chaining options

Unlike our SDK v1 session cache, the SDK v2 implementation in this PR does not include region as a cache key.
There are regional AWS STS endpoints for lower latency calls, but the lowest latency path is to just grab credentials from the cache if we already have them - the region they were originally taken from doesn't matter.
  • Loading branch information
GavinFrazar committed Dec 24, 2024
1 parent 60aaa6d commit 39816e4
Show file tree
Hide file tree
Showing 4 changed files with 444 additions and 72 deletions.
173 changes: 107 additions & 66 deletions lib/cloud/awsconfig/awsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package awsconfig
import (
"context"
"log/slog"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
Expand Down Expand Up @@ -47,16 +48,21 @@ const (
// This is used to generate aws configs for clients that must use an integration instead of ambient credentials.
type IntegrationCredentialProviderFunc func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error)

// AssumeRoleAPIClientFunc provides an AWS STS assume role API client.
type AssumeRoleAPIClientFunc func(aws.Config) stscreds.AssumeRoleAPIClient

// assumeRole is an AWS role ARN to assume, optionally with an external ID.
type assumeRole struct {
roleARN string
externalID string
}

// options is a struct of additional options for assuming an AWS role
// when construction an underlying AWS config.
type options struct {
// baseConfigis a config to use instead of the default config for an
// AWS region, which is used to enable role chaining.
baseConfig *aws.Config
// assumeRoleARN is the AWS IAM Role ARN to assume.
assumeRoleARN string
// assumeRoleExternalID is used to assume an external AWS IAM Role.
assumeRoleExternalID string
// assumeRoles are AWS IAM roles that should be assumed one by one in order,
// as a chain of assumed roles.
assumeRoles []assumeRole
// credentialsSource describes which source to use to fetch credentials.
credentialsSource credentialsSource
// integration is the name of the integration to be used to fetch the credentials.
Expand All @@ -67,22 +73,39 @@ type options struct {
customRetryer func() aws.Retryer
// maxRetries is the maximum number of retries to use for the config.
maxRetries *int
// newAssumeRoleAPIClientFn sets the STS assume role client provider func.
newAssumeRoleAPIClientFn AssumeRoleAPIClientFunc
}

func buildOptions(optFns ...OptionsFn) (*options, error) {
var opts options
for _, optFn := range optFns {
optFn(&opts)
}
if err := opts.checkAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
return &opts, nil
}

func (a *options) checkAndSetDefaults() error {
switch a.credentialsSource {
func (o *options) checkAndSetDefaults() error {
switch o.credentialsSource {
case credentialsSourceAmbient:
if a.integration != "" {
if o.integration != "" {
return trace.BadParameter("integration and ambient credentials cannot be used at the same time")
}
case credentialsSourceIntegration:
if a.integration == "" {
if o.integration == "" {
return trace.BadParameter("missing integration name")
}
default:
return trace.BadParameter("missing credentials source (ambient or integration)")
}

if o.newAssumeRoleAPIClientFn == nil {
o.newAssumeRoleAPIClientFn = newAssumeRoleAPIClient
}

return nil
}

Expand All @@ -93,8 +116,14 @@ type OptionsFn func(*options)
// WithAssumeRole configures options needed for assuming an AWS role.
func WithAssumeRole(roleARN, externalID string) OptionsFn {
return func(options *options) {
options.assumeRoleARN = roleARN
options.assumeRoleExternalID = externalID
if roleARN == "" {
// ignore empty role ARN for caller convenience.
return
}
options.assumeRoles = append(options.assumeRoles, assumeRole{
roleARN: roleARN,
externalID: externalID,
})
}
}

Expand Down Expand Up @@ -146,96 +175,108 @@ func WithIntegrationCredentialProvider(cred IntegrationCredentialProviderFunc) O
}
}

// WithAssumeRoleAPIClientFunc sets the STS API client factory func used to
// assume roles.
func WithAssumeRoleAPIClientFunc(fn AssumeRoleAPIClientFunc) OptionsFn {
return func(options *options) {
options.newAssumeRoleAPIClientFn = fn
}
}

// GetConfig returns an AWS config for the specified region, optionally
// assuming AWS IAM Roles.
func GetConfig(ctx context.Context, region string, opts ...OptionsFn) (aws.Config, error) {
var options options
for _, opt := range opts {
opt(&options)
}
if options.baseConfig == nil {
cfg, err := getConfigForRegion(ctx, region, options)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
options.baseConfig = &cfg
func GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) {
opts, err := buildOptions(optFns...)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
if options.assumeRoleARN == "" {
return *options.baseConfig, nil

cfg, err := getBaseConfig(ctx, region, opts)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
return getConfigForRole(ctx, region, options)
return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.newAssumeRoleAPIClientFn)
}

// ambientConfigProvider loads a new config using the environment variables.
func ambientConfigProvider(region string, cred aws.CredentialsProvider, options options) (aws.Config, error) {
opts := buildConfigOptions(region, cred, options)
cfg, err := config.LoadDefaultConfig(context.Background(), opts...)
// loadDefaultConfig loads a new config.
func loadDefaultConfig(ctx context.Context, region string, cred aws.CredentialsProvider, opts *options) (aws.Config, error) {
configOpts := buildConfigOptions(region, cred, opts)
cfg, err := config.LoadDefaultConfig(ctx, configOpts...)
return cfg, trace.Wrap(err)
}

func buildConfigOptions(region string, cred aws.CredentialsProvider, options options) []func(*config.LoadOptions) error {
opts := []func(*config.LoadOptions) error{
func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *options) []func(*config.LoadOptions) error {
configOpts := []func(*config.LoadOptions) error{
config.WithDefaultRegion(defaultRegion),
config.WithRegion(region),
config.WithCredentialsProvider(cred),
}
if modules.GetModules().IsBoringBinary() {
opts = append(opts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
configOpts = append(configOpts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
}
if options.customRetryer != nil {
opts = append(opts, config.WithRetryer(options.customRetryer))
if opts.customRetryer != nil {
configOpts = append(configOpts, config.WithRetryer(opts.customRetryer))
}
if options.maxRetries != nil {
opts = append(opts, config.WithRetryMaxAttempts(*options.maxRetries))
if opts.maxRetries != nil {
configOpts = append(configOpts, config.WithRetryMaxAttempts(*opts.maxRetries))
}
return opts
return configOpts
}

// getConfigForRegion returns AWS config for the specified region.
func getConfigForRegion(ctx context.Context, region string, options options) (aws.Config, error) {
if err := options.checkAndSetDefaults(); err != nil {
return aws.Config{}, trace.Wrap(err)
}

// getBaseConfig returns an AWS config without assuming any roles.
func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Config, error) {
var cred aws.CredentialsProvider
if options.credentialsSource == credentialsSourceIntegration {
if options.integrationCredentialsProvider == nil {
if opts.credentialsSource == credentialsSourceIntegration {
if opts.integrationCredentialsProvider == nil {
return aws.Config{}, trace.BadParameter("missing aws integration credential provider")
}

slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", options.integration)
slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", opts.integration)
var err error
cred, err = options.integrationCredentialsProvider(ctx, region, options.integration)
cred, err = opts.integrationCredentialsProvider(ctx, region, opts.integration)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
} else {
slog.DebugContext(ctx, "Initializing AWS config from environment", "region", region)
slog.DebugContext(ctx, "Initializing AWS config from default credential chain", "region", region)
}

cfg, err := ambientConfigProvider(region, cred, options)
cfg, err := loadDefaultConfig(ctx, region, cred, opts)
return cfg, trace.Wrap(err)
}

// getConfigForRole returns an AWS config for the specified region and role.
func getConfigForRole(ctx context.Context, region string, options options) (aws.Config, error) {
if err := options.checkAndSetDefaults(); err != nil {
return aws.Config{}, trace.Wrap(err)
func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []assumeRole, newSTSCltFn AssumeRoleAPIClientFunc) (aws.Config, error) {
for _, r := range roles {
cfg.Credentials = getAssumeRoleProvider(newSTSCltFn(cfg), r)
}
if len(roles) > 0 {
// no point caching every assumed role in the chain, we can just cache
// the last one.
cfg.Credentials = aws.NewCredentialsCache(cfg.Credentials, func(cacheOpts *aws.CredentialsCacheOptions) {
// expire early to avoid expiration race.
cacheOpts.ExpiryWindow = 5 * time.Minute
})
if _, err := cfg.Credentials.Retrieve(ctx); err != nil {
return aws.Config{}, trace.Wrap(err)
}
}
return cfg, nil
}

func newAssumeRoleAPIClient(cfg aws.Config) stscreds.AssumeRoleAPIClient {
return newSTSClient(cfg)
}

stsClient := sts.NewFromConfig(*options.baseConfig, func(o *sts.Options) {
func newSTSClient(cfg aws.Config) *sts.Client {
return sts.NewFromConfig(cfg, func(o *sts.Options) {
o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider())
})
cred := stscreds.NewAssumeRoleProvider(stsClient, options.assumeRoleARN, func(aro *stscreds.AssumeRoleOptions) {
if options.assumeRoleExternalID != "" {
aro.ExternalID = aws.String(options.assumeRoleExternalID)
}

func getAssumeRoleProvider(clt stscreds.AssumeRoleAPIClient, role assumeRole) aws.CredentialsProvider {
return stscreds.NewAssumeRoleProvider(clt, role.roleARN, func(aro *stscreds.AssumeRoleOptions) {
if role.externalID != "" {
aro.ExternalID = aws.String(role.externalID)
}
})
if _, err := cred.Retrieve(ctx); err != nil {
return aws.Config{}, trace.Wrap(err)
}

opts := buildConfigOptions(region, cred, options)
cfg, err := config.LoadDefaultConfig(ctx, opts...)
return cfg, trace.Wrap(err)
}
Loading

0 comments on commit 39816e4

Please sign in to comment.