From 001cfa9ed059769d13f1d2e744eaac02e25f1c22 Mon Sep 17 00:00:00 2001 From: chr12c Date: Fri, 20 Dec 2024 17:46:33 +0000 Subject: [PATCH] vai provider service specific config --- provider-service/vai/internal/config.go | 58 --------------- .../vai/internal/config/config.go | 71 +++++++++++++++++++ 2 files changed, 71 insertions(+), 58 deletions(-) delete mode 100644 provider-service/vai/internal/config.go create mode 100644 provider-service/vai/internal/config/config.go diff --git a/provider-service/vai/internal/config.go b/provider-service/vai/internal/config.go deleted file mode 100644 index 9334f3b1..00000000 --- a/provider-service/vai/internal/config.go +++ /dev/null @@ -1,58 +0,0 @@ -package internal - -import ( - "fmt" - "github.com/sky-uk/kfp-operator/argo/common" -) - -type VAIProviderConfig struct { - Name string `yaml:"name"` - Parameters Parameters `yaml:"parameters"` -} - -type Parameters struct { - VaiProject string `yaml:"vaiProject"` - VaiLocation string `yaml:"vaiLocation"` - VaiJobServiceAccount string `yaml:"vaiJobServiceAccount"` - GcsEndpoint string `yaml:"gcsEndpoint"` - PipelineBucket string `yaml:"pipelineBucket"` - EventsourcePipelineEventsSubscription string `yaml:"eventsourcePipelineEventsSubscription"` - MaxConcurrentRunCount int64 `yaml:"maxConcurrentRunCount"` -} - -func (vaipc VAIProviderConfig) VaiEndpoint() string { - return fmt.Sprintf("%s-aiplatform.googleapis.com:443", vaipc.Parameters.VaiLocation) -} - -func (vaipc VAIProviderConfig) parent() string { - return fmt.Sprintf(`projects/%s/locations/%s`, vaipc.Parameters.VaiProject, vaipc.Parameters.VaiLocation) -} - -func (vaipc VAIProviderConfig) pipelineJobName(name string) string { - return fmt.Sprintf("%s/pipelineJobs/%s", vaipc.parent(), name) -} - -func (vaipc VAIProviderConfig) pipelineStorageObject(pipelineName common.NamespacedName, pipelineVersion string) (string, error) { - namespaceName, err := pipelineName.String() - if err != nil { - return "", err - } - return fmt.Sprintf("%s/%s", namespaceName, pipelineVersion), nil -} - -func (vaipc VAIProviderConfig) pipelineUri(pipelineName common.NamespacedName, pipelineVersion string) (string, error) { - pipelineUri, err := vaipc.pipelineStorageObject(pipelineName, pipelineVersion) - if err != nil { - return "", err - } - return fmt.Sprintf("gs://%s/%s", vaipc.Parameters.PipelineBucket, pipelineUri), nil -} - -func (vaipc VAIProviderConfig) getMaxConcurrentRunCountOrDefault() int64 { - const defaultMaxConcurrentRunCount = 10 - if vaipc.Parameters.MaxConcurrentRunCount <= 0 { - return defaultMaxConcurrentRunCount - } else { - return vaipc.Parameters.MaxConcurrentRunCount - } -} diff --git a/provider-service/vai/internal/config/config.go b/provider-service/vai/internal/config/config.go new file mode 100644 index 00000000..107cca21 --- /dev/null +++ b/provider-service/vai/internal/config/config.go @@ -0,0 +1,71 @@ +package config + +import ( + "context" + "fmt" + "github.com/sky-uk/kfp-operator/argo/common" + "strings" + + "github.com/spf13/viper" +) + +// temporarily copied over here for vai specific provider config +type Config struct { + ProviderName string `mapstructure:"providerName"` + OperatorWebhook string `mapstructure:"operatorWebhook"` + Pod Pod `mapstructure:"pod"` + Server Server `mapstructure:"server"` + Parameters Parameters `mapstructure:"parameters"` +} + +type Pod struct { + Namespace string `mapstructure:"namespace"` +} + +type Server struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` +} + +type Parameters struct { + VaiProject string `mapstructure:"vaiProject"` + VaiLocation string `mapstructure:"vaiLocation"` + VaiJobServiceAccount string `mapstructure:"vaiJobServiceAccount"` + GcsEndpoint string `mapstructure:"gcsEndpoint"` + PipelineBucket string `mapstructure:"pipelineBucket"` + EventsourcePipelineEventsSubscription string `mapstructure:"eventsourcePipelineEventsSubscription"` + MaxConcurrentRunCount int64 `mapstructure:"maxConcurrentRunCount"` +} + +func LoadConfig(ctx context.Context) (*Config, error) { + logger := common.LoggerFromContext(ctx) + config, err := load() + + if err != nil { + logger.Error(err, "failed to load config file") + return nil, err + } + + logger.Info(fmt.Sprintf("loaded config: %+v", config)) + return config, nil +} + +func load() (*Config, error) { + viper.SetConfigName("config") + viper.AddConfigPath("/etc/provider-service") + viper.AddConfigPath(".") + + if err := viper.ReadInConfig(); err != nil { + return nil, fmt.Errorf("fatal error loading config %w", err) + } + + viper.AutomaticEnv() + viper.SetEnvKeyReplacer(strings.NewReplacer(`.`, `_`)) + + var config Config + if err := viper.Unmarshal(&config); err != nil { + return nil, fmt.Errorf("fatal error unmarshalling config %w", err) + } + + return &config, nil +}