diff --git a/lib/service/service.go b/lib/service/service.go index 3de8967dd71f3..53492052aa26e 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4583,6 +4583,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { TracerProvider: process.TracingProvider, AutomaticUpgradesChannels: cfg.Proxy.AutomaticUpgradesChannels, IntegrationAppHandler: connectionsHandler, + FeatureWatchInterval: utils.HalfJitter(web.DefaultFeatureWatchInterval * 2), } webHandler, err := web.NewHandler(webConfig) if err != nil { diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 771adabe1f3f3..76a440143b977 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -21,6 +21,7 @@ package web import ( + "cmp" "compress/gzip" "context" "encoding/base64" @@ -126,6 +127,9 @@ const ( // Example values: // - github-actions-ssh: indicates that the resource was added via the Bot GitHub Actions SSH flow webUIFlowLabelKey = "teleport.internal/ui-flow" + // DefaultFeatureWatchInterval is the default time in which the feature watcher + // should ping the auth server to check for updated features + DefaultFeatureWatchInterval = time.Minute * 5 ) // healthCheckAppServerFunc defines a function used to perform a health check @@ -162,10 +166,6 @@ type Handler struct { userConns atomic.Int32 // ClusterFeatures contain flags for supported and unsupported features. - // Note: This field can become stale since it's only set on initial proxy - // startup. To get the latest feature flags you'll need to ping from the - // auth server. - // https://github.com/gravitational/teleport/issues/39161 ClusterFeatures proto.Features // nodeWatcher is a services.NodeWatcher used by Assist to lookup nodes from @@ -325,6 +325,10 @@ type Config struct { // IntegrationAppHandler handles App Access requests which use an Integration. 
IntegrationAppHandler app.ServerHandler + + // FeatureWatchInterval is the interval between pings to the auth server + // to fetch new cluster features + FeatureWatchInterval time.Duration } // SetDefaults ensures proper default values are set if @@ -339,6 +343,8 @@ func (c *Config) SetDefaults() { if c.PresenceChecker == nil { c.PresenceChecker = client.RunPresenceTask } + + c.FeatureWatchInterval = cmp.Or(c.FeatureWatchInterval, DefaultFeatureWatchInterval) } type APIHandler struct { @@ -669,6 +675,8 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { } } + go h.startFeatureWatcher() + return &APIHandler{ handler: h, appHandler: appHandler, @@ -1692,14 +1700,7 @@ func (h *Handler) getWebConfig(w http.ResponseWriter, r *http.Request, p httprou } } - clusterFeatures := h.ClusterFeatures - // ping server to get cluster features since h.ClusterFeatures may be stale - pingResponse, err := h.GetProxyClient().Ping(r.Context()) - if err != nil { - h.log.WithError(err).Warn("Cannot retrieve cluster features, client may receive stale features") - } else { - clusterFeatures = *pingResponse.ServerFeatures - } + clusterFeatures := h.GetClusterFeatures() // get tunnel address to display on cloud instances tunnelPublicAddr := "" diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 854695f648866..dc7aedbe97d25 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -4592,6 +4592,7 @@ func TestGetWebConfig(t *testing.T) { env := newWebPack(t, 1, func(cfg *proxyConfig) { cfg.serviceConfig = svcConfig }) + handler := env.proxies[0].handler.handler // Set auth preference with passwordless. const MOTD = "Welcome to cluster, your activity will be recorded." 
@@ -4622,6 +4623,9 @@ func TestGetWebConfig(t *testing.T) { _, err = env.server.Auth().UpsertGithubConnector(ctx, github) require.NoError(t, err) + // start the feature watcher so the web config gets new features + env.clock.Advance(DefaultFeatureWatchInterval * 2) + expectedCfg := webclient.WebConfig{ Auth: webclient.WebConfigAuthSettings{ SecondFactor: constants.SecondFactorOptional, @@ -4669,6 +4673,7 @@ func TestGetWebConfig(t *testing.T) { AutomaticUpgrades: true, }, }) + env.clock.Advance(DefaultFeatureWatchInterval * 2) svcConfig.Proxy.AssistAPIKey = "test" require.NoError(t, err) @@ -4680,7 +4685,7 @@ func TestGetWebConfig(t *testing.T) { }, } require.NoError(t, channels.CheckAndSetDefaults()) - env.proxies[0].handler.handler.cfg.AutomaticUpgradesChannels = channels + handler.cfg.AutomaticUpgradesChannels = channels expectedCfg.IsCloud = true expectedCfg.IsUsageBasedBilling = true @@ -4689,14 +4694,20 @@ func TestGetWebConfig(t *testing.T) { expectedCfg.AssistEnabled = false expectedCfg.JoinActiveSessions = false - // request and verify enabled features are enabled. - re, err = clt.Get(ctx, endpoint, nil) - require.NoError(t, err) - require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG")) - str = strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "") - err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg) - require.NoError(t, err) - require.Equal(t, expectedCfg, cfg) + // request and verify enabled features are eventually enabled. 
+ require.EventuallyWithT(t, func(t *assert.CollectT) { + re, err := clt.Get(ctx, endpoint, nil) + if !assert.NoError(t, err) { + return + } + assert.True(t, bytes.HasPrefix(re.Bytes(), []byte("var GRV_CONFIG"))) + res := bytes.ReplaceAll(re.Bytes(), []byte("var GRV_CONFIG = "), []byte{}) + err = json.Unmarshal(res[:len(res)-1], &cfg) + assert.NoError(t, err) + diff := cmp.Diff(expectedCfg, cfg) + assert.Empty(t, diff) + + }, time.Second*5, time.Millisecond*50) // use mock client to assert that if ping returns an error, we'll default to // cluster config @@ -4715,15 +4726,22 @@ func TestGetWebConfig(t *testing.T) { IsUsageBasedBilling: false, }, }) + env.clock.Advance(DefaultFeatureWatchInterval * 2) // request and verify again - re, err = clt.Get(ctx, endpoint, nil) - require.NoError(t, err) - require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG")) - str = strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "") - err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg) - require.NoError(t, err) - require.Equal(t, expectedCfg, cfg) + require.EventuallyWithT(t, func(t *assert.CollectT) { + re, err := clt.Get(ctx, endpoint, nil) + if !assert.NoError(t, err) { + return + } + assert.True(t, bytes.HasPrefix(re.Bytes(), []byte("var GRV_CONFIG"))) + res := bytes.ReplaceAll(re.Bytes(), []byte("var GRV_CONFIG = "), []byte{}) + err = json.Unmarshal(res[:len(res)-1], &cfg) + assert.NoError(t, err) + diff := cmp.Diff(expectedCfg, cfg) + assert.Empty(t, diff) + + }, time.Second*5, time.Millisecond*50) } func TestGetWebConfig_IGSFeatureLimits(t *testing.T) { @@ -4745,6 +4763,8 @@ func TestGetWebConfig_IGSFeatureLimits(t *testing.T) { Questionnaire: true, }, }) + // start the feature watcher so the web config gets new features + env.clock.Advance(DefaultFeatureWatchInterval * 2) expectedCfg := webclient.WebConfig{ Auth: webclient.WebConfigAuthSettings{ @@ -4766,20 +4786,25 @@ func TestGetWebConfig_IGSFeatureLimits(t *testing.T) { IsUsageBasedBilling: true, } 
- // Make a request. clt := env.proxies[0].newClient(t) - endpoint := clt.Endpoint("web", "config.js") - re, err := clt.Get(ctx, endpoint, nil) - require.NoError(t, err) - require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG")) - - // Response is type application/javascript, we need to strip off the variable name - // and the semicolon at the end, then we are left with json like object. - var cfg webclient.WebConfig - str := strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "") - err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg) - require.NoError(t, err) - require.Equal(t, expectedCfg, cfg) + require.EventuallyWithT(t, func(t *assert.CollectT) { + // Make a request. + endpoint := clt.Endpoint("web", "config.js") + re, err := clt.Get(ctx, endpoint, nil) + if !assert.NoError(t, err) { + return + } + assert.True(t, bytes.HasPrefix(re.Bytes(), []byte("var GRV_CONFIG"))) + + // Response is type application/javascript, we need to strip off the variable name + // and the semicolon at the end, then we are left with json like object. + var cfg webclient.WebConfig + res := bytes.ReplaceAll(re.Bytes(), []byte("var GRV_CONFIG = "), []byte{}) + err = json.Unmarshal(res[:len(res)-1], &cfg) + assert.NoError(t, err) + diff := cmp.Diff(expectedCfg, cfg) + assert.Empty(t, diff) + }, time.Second*5, time.Millisecond*50) } func TestCreatePrivilegeToken(t *testing.T) { diff --git a/lib/web/features.go b/lib/web/features.go new file mode 100644 index 0000000000000..b387d0f176958 --- /dev/null +++ b/lib/web/features.go @@ -0,0 +1,71 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package web + +import ( + "github.com/gravitational/teleport/api/client/proto" +) + +// SetClusterFeatures replaces the stored flags for supported and unsupported features while holding `h.Mutex`. +// TODO(mcbattirola): make method unexported, fix tests using it to set +// test modules instead. +func (h *Handler) SetClusterFeatures(features proto.Features) { + h.Mutex.Lock() + defer h.Mutex.Unlock() + + h.ClusterFeatures = features +} + +// GetClusterFeatures returns the stored flags for supported and unsupported features, read while holding `h.Mutex`. +func (h *Handler) GetClusterFeatures() proto.Features { + h.Mutex.Lock() + defer h.Mutex.Unlock() + + return h.ClusterFeatures +} + +// startFeatureWatcher pings the auth server on every tick of `h.cfg.FeatureWatchInterval` and stores the returned features via `SetClusterFeatures`. +// It blocks until `h.cfg.Context` is done, so run it in its own goroutine, and start it only once per `Handler`: +// every call runs an independent ticker loop. A failed ping is logged and retried on the next tick. +// The watcher doesn't ping the auth server immediately upon start because features are +// already set by the config object in `NewHandler`.
+func (h *Handler) startFeatureWatcher() { + ticker := h.clock.NewTicker(h.cfg.FeatureWatchInterval) + h.log.WithField("interval", h.cfg.FeatureWatchInterval).Info("Proxy handler features watcher has started") + ctx := h.cfg.Context + + defer ticker.Stop() + for { + select { + case <-ticker.Chan(): + h.log.Info("Pinging auth server for features") + pingResponse, err := h.GetProxyClient().Ping(ctx) + if err != nil { + h.log.WithError(err).Error("Auth server ping failed") + continue + } + + h.SetClusterFeatures(*pingResponse.ServerFeatures) + h.log.WithField("features", pingResponse.ServerFeatures).Info("Done updating proxy features") + case <-ctx.Done(): + h.log.Info("Feature service has stopped") + return + } + } +} diff --git a/lib/web/features_test.go b/lib/web/features_test.go new file mode 100644 index 0000000000000..9e167f58324ca --- /dev/null +++ b/lib/web/features_test.go @@ -0,0 +1,160 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version.
+ */ + +package web + +import ( + "context" + "log/slog" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/lib/auth/authclient" +) + +// mockedFeatureGetter is a test auth client with a mocked Ping method +// that returns a clone of its internal features +type mockedFeatureGetter struct { + authclient.ClientI + features proto.Features +} + +func (m *mockedFeatureGetter) Ping(ctx context.Context) (proto.PingResponse, error) { + return proto.PingResponse{ + ServerFeatures: utils.CloneProtoMsg(&m.features), + }, nil +} + +func (m *mockedFeatureGetter) setFeatures(f proto.Features) { + m.features = f +} + +func TestFeaturesWatcher(t *testing.T) { + clock := clockwork.NewFakeClock() + + mockClient := &mockedFeatureGetter{features: proto.Features{ + Kubernetes: true, + AccessRequests: &proto.AccessRequestsFeature{}, + }} + + ctx, cancel := context.WithCancel(context.Background()) + handler := &Handler{ + cfg: Config{ + FeatureWatchInterval: 100 * time.Millisecond, + ProxyClient: mockClient, + Context: ctx, + }, + clock: clock, + ClusterFeatures: proto.Features{}, + log: newPackageLogger(), + logger: slog.Default().With(teleport.ComponentKey, teleport.ComponentWeb), + } + + // before running the watcher, features should match the value passed to the handler + requireFeatures(t, clock, proto.Features{}, handler.GetClusterFeatures) + + go handler.startFeatureWatcher() + clock.BlockUntil(1) + + // after starting the watcher, handler.GetClusterFeatures should return + // values matching the client's response + features := proto.Features{ + Kubernetes: true, + AccessRequests: &proto.AccessRequestsFeature{}, + } + expected := utils.CloneProtoMsg(&features) + requireFeatures(t, clock, *expected,
handler.GetClusterFeatures) + + // update values once again and check if the features are properly updated + features = proto.Features{ + Kubernetes: false, + AccessRequests: &proto.AccessRequestsFeature{}, + } + mockClient.setFeatures(features) + expected = utils.CloneProtoMsg(&features) + requireFeatures(t, clock, *expected, handler.GetClusterFeatures) + + // test updating features + features = proto.Features{ + Kubernetes: true, + ExternalAuditStorage: true, + AccessList: &proto.AccessListFeature{CreateLimit: 10}, + AccessMonitoring: &proto.AccessMonitoringFeature{}, + App: true, + AccessRequests: &proto.AccessRequestsFeature{}, + } + mockClient.setFeatures(features) + + expected = &proto.Features{ + Kubernetes: true, + ExternalAuditStorage: true, + AccessList: &proto.AccessListFeature{CreateLimit: 10}, + AccessMonitoring: &proto.AccessMonitoringFeature{}, + App: true, + AccessRequests: &proto.AccessRequestsFeature{}, + } + requireFeatures(t, clock, *expected, handler.GetClusterFeatures) + + // stop watcher and ensure it stops updating features + cancel() + features = proto.Features{ + Kubernetes: !features.Kubernetes, + App: !features.App, + DB: true, + AccessRequests: &proto.AccessRequestsFeature{}, + } + mockClient.setFeatures(features) + notExpected := utils.CloneProtoMsg(&features) + // assert the handler never gets these last features as the watcher is stopped + neverFeatures(t, clock, *notExpected, handler.GetClusterFeatures) +} + +// requireFeatures is a helper function that advances the fake clock by 1 second, then +// calls `getFeatures` every 100ms for up to 5 seconds, until it +// returns the expected result (`want`).
+func requireFeatures(t *testing.T, fakeClock clockwork.FakeClock, want proto.Features, getFeatures func() proto.Features) { + t.Helper() + + // Advance the clock so the service fetches and stores features + fakeClock.Advance(1 * time.Second) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + diff := cmp.Diff(want, getFeatures()) + assert.Empty(t, diff) + }, 5*time.Second, time.Millisecond*100) +} + +// neverFeatures is a helper function that advances the clock, then +// calls `getFeatures` every 100ms for up to 1 second. If at some point `getFeatures` +// returns `doNotWant`, the test fails. +func neverFeatures(t *testing.T, fakeClock clockwork.FakeClock, doNotWant proto.Features, getFeatures func() proto.Features) { + t.Helper() + + fakeClock.Advance(1 * time.Second) + require.Never(t, func() bool { + return cmp.Diff(doNotWant, getFeatures()) == "" + }, 1*time.Second, time.Millisecond*100) +} diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go index 1819a66059ece..631643f1fd7d4 100644 --- a/lib/web/integrations_awsoidc.go +++ b/lib/web/integrations_awsoidc.go @@ -125,7 +125,7 @@ func (h *Handler) awsOIDCDeployService(w http.ResponseWriter, r *http.Request, p } teleportVersionTag := teleport.Version - if automaticUpgrades(h.ClusterFeatures) { + if automaticUpgrades(h.GetClusterFeatures()) { cloudStableVersion, err := h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx) if err != nil { return "", trace.Wrap(err) @@ -178,7 +178,7 @@ func (h *Handler) awsOIDCDeployDatabaseServices(w http.ResponseWriter, r *http.R } teleportVersionTag := teleport.Version - if automaticUpgrades(h.ClusterFeatures) { + if automaticUpgrades(h.GetClusterFeatures()) { cloudStableVersion, err := h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx) if err != nil { return "", trace.Wrap(err) @@ -480,7 +480,7 @@ func (h *Handler) awsOIDCEnrollEKSClusters(w http.ResponseWriter, r *http.Reques return nil, trace.BadParameter("an integration name is required") } -
agentVersion, err := kubeutils.GetKubeAgentVersion(ctx, h.cfg.ProxyClient, h.ClusterFeatures, h.cfg.AutomaticUpgradesChannels) + agentVersion, err := kubeutils.GetKubeAgentVersion(ctx, h.cfg.ProxyClient, h.GetClusterFeatures(), h.cfg.AutomaticUpgradesChannels) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/web/join_tokens.go b/lib/web/join_tokens.go index 9b5bfd459aaf3..31338c34837e3 100644 --- a/lib/web/join_tokens.go +++ b/lib/web/join_tokens.go @@ -216,7 +216,7 @@ func (h *Handler) createTokenHandle(w http.ResponseWriter, r *http.Request, para func (h *Handler) getAutoUpgrades(ctx context.Context) (bool, string, error) { var autoUpgradesVersion string var err error - autoUpgrades := automaticUpgrades(h.ClusterFeatures) + autoUpgrades := automaticUpgrades(h.GetClusterFeatures()) if autoUpgrades { autoUpgradesVersion, err = h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx) if err != nil {