Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update awsconfig #50561

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 106 additions & 68 deletions lib/cloud/awsconfig/awsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package awsconfig
import (
"context"
"log/slog"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
Expand Down Expand Up @@ -47,16 +48,21 @@ const (
// This is used to generate aws configs for clients that must use an integration instead of ambient credentials.
type IntegrationCredentialProviderFunc func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error)

// AssumeRoleClientProviderFunc provides an AWS STS assume role API client.
type AssumeRoleClientProviderFunc func(aws.Config) stscreds.AssumeRoleAPIClient

// assumeRole is an AWS role ARN to assume, optionally with an external ID.
type assumeRole struct {
roleARN string
externalID string
}

// options is a struct of additional options for assuming an AWS role
// when construction an underlying AWS config.
type options struct {
// baseConfigis a config to use instead of the default config for an
// AWS region, which is used to enable role chaining.
baseConfig *aws.Config
// assumeRoleARN is the AWS IAM Role ARN to assume.
assumeRoleARN string
// assumeRoleExternalID is used to assume an external AWS IAM Role.
assumeRoleExternalID string
// assumeRoles are AWS IAM roles that should be assumed one by one in order,
// as a chain of assumed roles.
assumeRoles []assumeRole
// credentialsSource describes which source to use to fetch credentials.
credentialsSource credentialsSource
// integration is the name of the integration to be used to fetch the credentials.
Expand All @@ -67,22 +73,43 @@ type options struct {
customRetryer func() aws.Retryer
// maxRetries is the maximum number of retries to use for the config.
maxRetries *int
// assumeRoleClientProvider sets the STS assume role client provider func.
assumeRoleClientProvider AssumeRoleClientProviderFunc
}

func (a *options) checkAndSetDefaults() error {
switch a.credentialsSource {
func buildOptions(optFns ...OptionsFn) (*options, error) {
var opts options
for _, optFn := range optFns {
optFn(&opts)
}
if err := opts.checkAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
return &opts, nil
}

func (o *options) checkAndSetDefaults() error {
switch o.credentialsSource {
case credentialsSourceAmbient:
if a.integration != "" {
if o.integration != "" {
return trace.BadParameter("integration and ambient credentials cannot be used at the same time")
}
case credentialsSourceIntegration:
if a.integration == "" {
if o.integration == "" {
return trace.BadParameter("missing integration name")
}
default:
return trace.BadParameter("missing credentials source (ambient or integration)")
}

if o.assumeRoleClientProvider == nil {
o.assumeRoleClientProvider = func(cfg aws.Config) stscreds.AssumeRoleAPIClient {
return sts.NewFromConfig(cfg, func(o *sts.Options) {
o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider())
})
}
}

return nil
}

Expand All @@ -93,8 +120,14 @@ type OptionsFn func(*options)
// WithAssumeRole configures options needed for assuming an AWS role.
func WithAssumeRole(roleARN, externalID string) OptionsFn {
return func(options *options) {
options.assumeRoleARN = roleARN
options.assumeRoleExternalID = externalID
if roleARN == "" {
// ignore empty role ARN for caller convenience.
return
}
options.assumeRoles = append(options.assumeRoles, assumeRole{
roleARN: roleARN,
externalID: externalID,
})
}
}

Expand Down Expand Up @@ -146,96 +179,101 @@ func WithIntegrationCredentialProvider(cred IntegrationCredentialProviderFunc) O
}
}

// WithAssumeRoleClientProviderFunc sets the STS API client factory func used to
// assume roles.
func WithAssumeRoleClientProviderFunc(fn AssumeRoleClientProviderFunc) OptionsFn {
return func(options *options) {
options.assumeRoleClientProvider = fn
}
}

// GetConfig returns an AWS config for the specified region, optionally
// assuming AWS IAM Roles.
func GetConfig(ctx context.Context, region string, opts ...OptionsFn) (aws.Config, error) {
var options options
for _, opt := range opts {
opt(&options)
}
if options.baseConfig == nil {
cfg, err := getConfigForRegion(ctx, region, options)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
options.baseConfig = &cfg
func GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Config, error) {
opts, err := buildOptions(optFns...)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
if options.assumeRoleARN == "" {
return *options.baseConfig, nil

cfg, err := getBaseConfig(ctx, region, opts)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
return getConfigForRole(ctx, region, options)
return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.assumeRoleClientProvider)
}

// ambientConfigProvider loads a new config using the environment variables.
func ambientConfigProvider(region string, cred aws.CredentialsProvider, options options) (aws.Config, error) {
opts := buildConfigOptions(region, cred, options)
cfg, err := config.LoadDefaultConfig(context.Background(), opts...)
// loadDefaultConfig loads a new config.
func loadDefaultConfig(ctx context.Context, region string, cred aws.CredentialsProvider, opts *options) (aws.Config, error) {
configOpts := buildConfigOptions(region, cred, opts)
cfg, err := config.LoadDefaultConfig(ctx, configOpts...)
return cfg, trace.Wrap(err)
}

func buildConfigOptions(region string, cred aws.CredentialsProvider, options options) []func(*config.LoadOptions) error {
opts := []func(*config.LoadOptions) error{
func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *options) []func(*config.LoadOptions) error {
configOpts := []func(*config.LoadOptions) error{
config.WithDefaultRegion(defaultRegion),
config.WithRegion(region),
config.WithCredentialsProvider(cred),
}
if modules.GetModules().IsBoringBinary() {
opts = append(opts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
configOpts = append(configOpts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
}
if options.customRetryer != nil {
opts = append(opts, config.WithRetryer(options.customRetryer))
if opts.customRetryer != nil {
configOpts = append(configOpts, config.WithRetryer(opts.customRetryer))
}
if options.maxRetries != nil {
opts = append(opts, config.WithRetryMaxAttempts(*options.maxRetries))
if opts.maxRetries != nil {
configOpts = append(configOpts, config.WithRetryMaxAttempts(*opts.maxRetries))
}
return opts
return configOpts
}

// getConfigForRegion returns AWS config for the specified region.
func getConfigForRegion(ctx context.Context, region string, options options) (aws.Config, error) {
if err := options.checkAndSetDefaults(); err != nil {
return aws.Config{}, trace.Wrap(err)
}

// getBaseConfig returns an AWS config without assuming any roles.
func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Config, error) {
var cred aws.CredentialsProvider
if options.credentialsSource == credentialsSourceIntegration {
if options.integrationCredentialsProvider == nil {
if opts.credentialsSource == credentialsSourceIntegration {
if opts.integrationCredentialsProvider == nil {
return aws.Config{}, trace.BadParameter("missing aws integration credential provider")
}

slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", options.integration)
slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", opts.integration)
var err error
cred, err = options.integrationCredentialsProvider(ctx, region, options.integration)
cred, err = opts.integrationCredentialsProvider(ctx, region, opts.integration)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
} else {
slog.DebugContext(ctx, "Initializing AWS config from environment", "region", region)
slog.DebugContext(ctx, "Initializing AWS config from default credential chain", "region", region)
}

cfg, err := ambientConfigProvider(region, cred, options)
cfg, err := loadDefaultConfig(ctx, region, cred, opts)
return cfg, trace.Wrap(err)
}

// getConfigForRole returns an AWS config for the specified region and role.
func getConfigForRole(ctx context.Context, region string, options options) (aws.Config, error) {
if err := options.checkAndSetDefaults(); err != nil {
return aws.Config{}, trace.Wrap(err)
func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []assumeRole, newCltFn AssumeRoleClientProviderFunc) (aws.Config, error) {
for _, r := range roles {
cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r)
}

stsClient := sts.NewFromConfig(*options.baseConfig, func(o *sts.Options) {
o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider())
})
cred := stscreds.NewAssumeRoleProvider(stsClient, options.assumeRoleARN, func(aro *stscreds.AssumeRoleOptions) {
if options.assumeRoleExternalID != "" {
aro.ExternalID = aws.String(options.assumeRoleExternalID)
if len(roles) > 0 {
// no point caching every assumed role in the chain, we can just cache
// the last one.
cfg.Credentials = aws.NewCredentialsCache(cfg.Credentials, func(cacheOpts *aws.CredentialsCacheOptions) {
// expire early to avoid expiration race.
cacheOpts.ExpiryWindow = 5 * time.Minute
})
if _, err := cfg.Credentials.Retrieve(ctx); err != nil {
return aws.Config{}, trace.Wrap(err)
}
})
if _, err := cred.Retrieve(ctx); err != nil {
return aws.Config{}, trace.Wrap(err)
}
return cfg, nil
}

opts := buildConfigOptions(region, cred, options)
cfg, err := config.LoadDefaultConfig(ctx, opts...)
return cfg, trace.Wrap(err)
func getAssumeRoleProvider(ctx context.Context, clt stscreds.AssumeRoleAPIClient, role assumeRole) aws.CredentialsProvider {
slog.DebugContext(ctx, "Initializing AWS session for assumed role",
"assumed_role", role.roleARN,
)
return stscreds.NewAssumeRoleProvider(clt, role.roleARN, func(aro *stscreds.AssumeRoleOptions) {
if role.externalID != "" {
aro.ExternalID = aws.String(role.externalID)
}
})
}
Loading
Loading