diff --git a/CHANGELOG.md b/CHANGELOG.md index b08240862..25107fd92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,9 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Bugfixes * Update Go runtime to 1.21.3. [#1102](https://github.com/elastic/package-registry/pull/1102) +* Raise an error if the value of environment variables used to set parameters are not valid [#1103](https://github.com/elastic/package-registry/pull/1103) ### Added +* Add new parameter to specify minimum TLS version [#1103](https://github.com/elastic/package-registry/pull/1103) + ### Deprecated ### Known Issues diff --git a/flags.go b/flags.go index 2571700eb..802fe00bc 100644 --- a/flags.go +++ b/flags.go @@ -5,29 +5,72 @@ package main import ( + "crypto/tls" + "errors" "flag" + "fmt" "os" "strings" ) -func parseFlags() { - parseFlagSetWithArgs(flag.CommandLine, os.Args) +var supportedTLSVersions map[string]uint16 = map[string]uint16{ + "1.0": tls.VersionTLS10, + "1.1": tls.VersionTLS11, + "1.2": tls.VersionTLS12, + "1.3": tls.VersionTLS13, } -func parseFlagSetWithArgs(flagSet *flag.FlagSet, args []string) { - flagsFromEnv(flagSet) +type tlsVersionValue uint16 + +func (t tlsVersionValue) String() string { + switch t { + case tls.VersionTLS10: + return "1.0" + case tls.VersionTLS11: + return "1.1" + case tls.VersionTLS12: + return "1.2" + case tls.VersionTLS13: + return "1.3" + default: + return "" + } +} + +func (t *tlsVersionValue) Set(s string) error { + if _, ok := supportedTLSVersions[s]; !ok { + return fmt.Errorf("unsupported TLS version: %s", s) + } + *t = tlsVersionValue(supportedTLSVersions[s]) + return nil +} + +func parseFlags() error { + return parseFlagSetWithArgs(flag.CommandLine, os.Args) +} + +func parseFlagSetWithArgs(flagSet *flag.FlagSet, args []string) error { + err := flagsFromEnv(flagSet) + if err != nil { + return err + } // Skip args[0] as flag.Parse() does. flagSet.Parse(args[1:]) + return nil } -func flagsFromEnv(flagSet *flag.FlagSet) { +func flagsFromEnv(flagSet *flag.FlagSet) error { + var flagErrors error flagSet.VisitAll(func(f *flag.Flag) { envName := flagEnvName(f.Name) if value, found := os.LookupEnv(envName); found { - f.Value.Set(value) + if err := f.Value.Set(value); err != nil { + flagErrors = errors.Join(flagErrors, fmt.Errorf("failed to set -%s: %v", f.Name, err)) + } } }) + return flagErrors } const flagEnvPrefix = "EPR_" diff --git a/flags_test.go b/flags_test.go index dcb9cde6c..d91240124 100644 --- a/flags_test.go +++ b/flags_test.go @@ -36,7 +36,8 @@ func TestFlagsPrecedence(t *testing.T) { require.Equal(t, "default", dummyFlag) args := []string{"test", "-test-precedence-dummy=" + expected} - parseFlagSetWithArgs(flagSet, args) + err := parseFlagSetWithArgs(flagSet, args) + require.NoError(t, err) require.Equal(t, expected, dummyFlag) } diff --git a/main.go b/main.go index ad73f1d84..9569b3563 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ package main import ( "context" + "crypto/tls" "flag" "fmt" "log" @@ -53,6 +54,8 @@ var ( tlsCertFile string tlsKeyFile string + tlsMinVersionValue tlsVersionValue + dryRun bool configPath string @@ -83,6 +86,7 @@ func init() { flag.StringVar(&logType, "log-type", util.DefaultLoggerType, "log type (ecs, dev)") flag.StringVar(&tlsCertFile, "tls-cert", "", "Path of the TLS certificate.") flag.StringVar(&tlsKeyFile, "tls-key", "", "Path of the TLS key.") + flag.Var(&tlsMinVersionValue, "tls-min-version", "Minimum version TLS supported.") flag.StringVar(&configPath, "config", "config.yml", "Path to the configuration file.") flag.StringVar(&httpProfAddress, "httpprof", "", "Enable HTTP profiler listening on the given address.") // This flag is experimental and might be removed in the future or renamed @@ -108,7 +112,16 @@ type Config struct { } func main() { - parseFlags() + err := parseFlags() + if err != nil { + log.Fatal(err) + } + + if tlsMinVersionValue > 0 { + if tlsCertFile == "" || tlsKeyFile == "" { + log.Fatalf("-tls-min-version set but missing TLS cert and key files (-tls-cert and -tls-key)") + } + } if printVersionInfo { fmt.Printf("Elastic Package Registry version %v\n", version) @@ -241,7 +254,11 @@ func initServer(logger *zap.Logger, apmTracer *apm.Tracer, config *Config) *http router := mustLoadRouter(logger, config, indexer) apmgorilla.Instrument(router, apmgorilla.WithTracer(apmTracer)) - return &http.Server{Addr: address, Handler: router} + var tlsConfig tls.Config + if tlsMinVersionValue > 0 { + tlsConfig.MinVersion = uint16(tlsMinVersionValue) + } + return &http.Server{Addr: address, Handler: router, TLSConfig: &tlsConfig} } func runServer(server *http.Server) error {